mirror of
https://github.com/bkaradzic/bgfx.git
synced 2026-02-17 12:42:34 +01:00
Updated spirv-tools.
This commit is contained in:
@@ -1 +1 @@
|
||||
"v2025.2", "SPIRV-Tools v2025.2 v2025.2.rc2-58-g007a1f89"
|
||||
"v2025.3", "SPIRV-Tools v2025.3 v2025.3.rc1-110-g8fbe2387"
|
||||
|
||||
17569
3rdparty/spirv-tools/include/generated/core_tables_body.inc
vendored
17569
3rdparty/spirv-tools/include/generated/core_tables_body.inc
vendored
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ enum class PrintingClass : uint32_t {
|
||||
kDevice_Side_Enqueue,
|
||||
kExtension,
|
||||
kFunction,
|
||||
kGraph,
|
||||
kGroup,
|
||||
kImage,
|
||||
kMemory,
|
||||
@@ -43,6 +44,7 @@ enum Extension : uint32_t {
|
||||
kSPV_AMD_texture_gather_bias_lod,
|
||||
kSPV_ARM_cooperative_matrix_layouts,
|
||||
kSPV_ARM_core_builtins,
|
||||
kSPV_ARM_graph,
|
||||
kSPV_ARM_tensors,
|
||||
kSPV_EXT_arithmetic_fence,
|
||||
kSPV_EXT_demote_to_helper_invocation,
|
||||
@@ -91,6 +93,7 @@ enum Extension : uint32_t {
|
||||
kSPV_INTEL_fpga_memory_attributes,
|
||||
kSPV_INTEL_fpga_reg,
|
||||
kSPV_INTEL_function_pointers,
|
||||
kSPV_INTEL_function_variants,
|
||||
kSPV_INTEL_global_variable_fpga_decorations,
|
||||
kSPV_INTEL_global_variable_host_access,
|
||||
kSPV_INTEL_inline_assembly,
|
||||
@@ -182,6 +185,7 @@ enum Extension : uint32_t {
|
||||
kSPV_NV_stereo_view_rendering,
|
||||
kSPV_NV_tensor_addressing,
|
||||
kSPV_NV_viewport_array2,
|
||||
kSPV_QCOM_cooperative_matrix_conversion,
|
||||
kSPV_QCOM_image_processing,
|
||||
kSPV_QCOM_image_processing2,
|
||||
kSPV_QCOM_tile_shading,
|
||||
|
||||
@@ -80,6 +80,8 @@ typedef enum spv_result_t {
|
||||
SPV_ERROR_INVALID_DATA = -14, // Indicates data rules validation failure.
|
||||
SPV_ERROR_MISSING_EXTENSION = -15,
|
||||
SPV_ERROR_WRONG_VERSION = -16, // Indicates wrong SPIR-V version
|
||||
SPV_ERROR_FNVAR =
|
||||
-17, // Error related to SPV_INTEL_function_variants extension
|
||||
SPV_FORCE_32_BIT_ENUM(spv_result_t)
|
||||
} spv_result_t;
|
||||
|
||||
@@ -189,36 +191,24 @@ typedef enum spv_operand_type_t {
|
||||
SPV_OPERAND_TYPE_MEMORY_ACCESS, // SPIR-V Sec 3.26
|
||||
SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE, // SPIR-V Sec 3.FSR
|
||||
|
||||
// NOTE: New concrete enum values should be added at the end.
|
||||
// NOTE: New concrete enum values should be added at the end.
|
||||
|
||||
// The "optional" and "variable" operand types are only used internally by
|
||||
// the assembler and the binary parser.
|
||||
// There are two categories:
|
||||
// Optional : expands to 0 or 1 operand, like ? in regular expressions.
|
||||
// Variable : expands to 0, 1 or many operands or pairs of operands.
|
||||
// This is similar to * in regular expressions.
|
||||
// The "optional" and "variable" operand types are only used internally by
|
||||
// the assembler and the binary parser.
|
||||
// There are two categories:
|
||||
// Optional : expands to 0 or 1 operand, like ? in regular expressions.
|
||||
// Variable : expands to 0, 1 or many operands or pairs of operands.
|
||||
// This is similar to * in regular expressions.
|
||||
|
||||
// NOTE: These FIRST_* and LAST_* enum values are DEPRECATED.
|
||||
// The concept of "optional" and "variable" operand types are only intended
|
||||
// for use as an implementation detail of parsing SPIR-V, either in text or
|
||||
// binary form. Instead of using enum ranges, use characteristic function
|
||||
// spvOperandIsConcrete.
|
||||
// The use of enum value ranges in a public API makes it difficult to insert
|
||||
// new values into a range without also breaking binary compatibility.
|
||||
//
|
||||
// Macros for defining bounds on optional and variable operand types.
|
||||
// Any variable operand type is also optional.
|
||||
// TODO(dneto): Remove SPV_OPERAND_TYPE_FIRST_* and SPV_OPERAND_TYPE_LAST_*
|
||||
#define FIRST_OPTIONAL(ENUM) ENUM, SPV_OPERAND_TYPE_FIRST_OPTIONAL_TYPE = ENUM
|
||||
#define FIRST_VARIABLE(ENUM) ENUM, SPV_OPERAND_TYPE_FIRST_VARIABLE_TYPE = ENUM
|
||||
#define LAST_VARIABLE(ENUM) \
|
||||
ENUM, SPV_OPERAND_TYPE_LAST_VARIABLE_TYPE = ENUM, \
|
||||
SPV_OPERAND_TYPE_LAST_OPTIONAL_TYPE = ENUM
|
||||
// Use characteristic function spvOperandIsConcrete to classify the
|
||||
// operand types; when it returns false, the operand is optional or variable.
|
||||
//
|
||||
// Any variable operand type is also optional.
|
||||
|
||||
// An optional operand represents zero or one logical operands.
|
||||
// In an instruction definition, this may only appear at the end of the
|
||||
// operand types.
|
||||
FIRST_OPTIONAL(SPV_OPERAND_TYPE_OPTIONAL_ID),
|
||||
SPV_OPERAND_TYPE_OPTIONAL_ID,
|
||||
// An optional image operand type.
|
||||
SPV_OPERAND_TYPE_OPTIONAL_IMAGE,
|
||||
// An optional memory access type.
|
||||
@@ -243,7 +233,7 @@ typedef enum spv_operand_type_t {
|
||||
// A variable operand represents zero or more logical operands.
|
||||
// In an instruction definition, this may only appear at the end of the
|
||||
// operand types.
|
||||
FIRST_VARIABLE(SPV_OPERAND_TYPE_VARIABLE_ID),
|
||||
SPV_OPERAND_TYPE_VARIABLE_ID,
|
||||
SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER,
|
||||
// A sequence of zero or more pairs of (typed literal integer, Id).
|
||||
// Expands to zero or more:
|
||||
@@ -251,7 +241,7 @@ typedef enum spv_operand_type_t {
|
||||
// where the literal number must always be an integer of some sort.
|
||||
SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID,
|
||||
// A sequence of zero or more pairs of (Id, Literal integer)
|
||||
LAST_VARIABLE(SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER),
|
||||
SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER,
|
||||
|
||||
// The following are concrete enum types from the DebugInfo extended
|
||||
// instruction set.
|
||||
@@ -344,6 +334,10 @@ typedef enum spv_operand_type_t {
|
||||
SPV_OPERAND_TYPE_TENSOR_OPERANDS,
|
||||
SPV_OPERAND_TYPE_OPTIONAL_TENSOR_OPERANDS,
|
||||
|
||||
// SPV_INTEL_function_variants
|
||||
SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY,
|
||||
SPV_OPERAND_TYPE_VARIABLE_CAPABILITY,
|
||||
|
||||
// This is a sentinel value, and does not represent an operand type.
|
||||
// It should come last.
|
||||
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
|
||||
@@ -370,6 +364,7 @@ typedef enum spv_ext_inst_type_t {
|
||||
SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION,
|
||||
SPV_EXT_INST_TYPE_NONSEMANTIC_SHADER_DEBUGINFO_100,
|
||||
SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION,
|
||||
SPV_EXT_INST_TYPE_TOSA_001000_1,
|
||||
|
||||
// Multiple distinct extended instruction set types could return this
|
||||
// value, if they are prefixed with NonSemantic. and are otherwise
|
||||
@@ -438,7 +433,7 @@ typedef enum spv_binary_to_text_options_t {
|
||||
|
||||
// The default id bound is to the minimum value for the id limit
|
||||
// in the spir-v specification under the section "Universal Limits".
|
||||
const uint32_t kDefaultMaxIdBound = 0x3FFFFF;
|
||||
const static uint32_t kDefaultMaxIdBound = 0x3FFFFF;
|
||||
|
||||
// Structures
|
||||
|
||||
@@ -772,6 +767,7 @@ SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetAllowOffsetTextureOperand(
|
||||
spv_validator_options options, bool val);
|
||||
|
||||
// Allow base operands of some bit operations to be non-32-bit wide.
|
||||
// Was added for VK_KHR_maintenance9
|
||||
SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetAllowVulkan32BitBitwise(
|
||||
spv_validator_options options, bool val);
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ class SPIRV_TOOLS_EXPORT ValidatorOptions {
|
||||
}
|
||||
|
||||
// Allow base operands of some bit operations to be non-32-bit wide.
|
||||
// Was added for VK_KHR_maintenance9
|
||||
// Forwards |val| to the C API setter
// spvValidatorOptionsSetAllowVulkan32BitBitwise for the wrapped options_.
void SetAllowVulkan32BitBitwise(bool val) {
|
||||
spvValidatorOptionsSetAllowVulkan32BitBitwise(options_, val);
|
||||
}
|
||||
|
||||
@@ -67,12 +67,36 @@ class SPIRV_TOOLS_EXPORT LinkerOptions {
|
||||
allow_ptr_type_mismatch_ = allow_ptr_type_mismatch;
|
||||
}
|
||||
|
||||
std::string GetFnVarTargetsCsv() const { return fnvar_targets_csv_; }
|
||||
void SetFnVarTargetsCsv(std::string fnvar_targets_csv) {
|
||||
fnvar_targets_csv_ = fnvar_targets_csv;
|
||||
}
|
||||
|
||||
std::string GetFnVarArchitecturesCsv() const {
|
||||
return fnvar_architectures_csv_;
|
||||
}
|
||||
void SetFnVarArchitecturesCsv(std::string fnvar_architectures_csv) {
|
||||
fnvar_architectures_csv_ = fnvar_architectures_csv;
|
||||
}
|
||||
|
||||
bool GetHasFnVarCapabilities() const { return has_fnvar_capabilities_; }
|
||||
void SetHasFnVarCapabilities(bool fnvar_capabilities) {
|
||||
has_fnvar_capabilities_ = fnvar_capabilities;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetInFiles() const { return in_files_; }
|
||||
void SetInFiles(std::vector<std::string> in_files) { in_files_ = in_files; }
|
||||
|
||||
private:
|
||||
// Boolean linker options; every one of them defaults to false.
bool create_library_{false};
bool verify_ids_{false};
bool allow_partial_linkage_{false};
bool use_highest_version_{false};
bool allow_ptr_type_mismatch_{false};
// Function-variant CSV options. An empty string means the option is not set.
std::string fnvar_targets_csv_{};
std::string fnvar_architectures_csv_{};
bool has_fnvar_capabilities_{false};
// Starts out empty. Do not write `{{}}` here: the inner braces would
// initialize one element, producing a vector that holds a single empty
// string instead of an empty vector.
std::vector<std::string> in_files_{};
|
||||
};
|
||||
|
||||
// Links one or more SPIR-V modules into a new SPIR-V module. That is, combine
|
||||
|
||||
@@ -1022,6 +1022,16 @@ Optimizer::PassToken CreateSplitCombinedImageSamplerPass();
|
||||
// This pass assumes binding numbers are not applied via decoration groups
|
||||
// (OpDecorationGroup).
|
||||
Optimizer::PassToken CreateResolveBindingConflictsPass();
|
||||
|
||||
// Create a pass to canonicalize IDs to improve compression of SPIR-V binary
|
||||
// files. The resulting modules have an increased ID range (IDs are not as
|
||||
// tightly packed around zero), but will compress better when multiple modules
|
||||
// are compressed together, since the compressor's dictionary can find better
|
||||
// cross-module commonality. This pass should be run after most optimization
|
||||
// passes, but before --strip-debug, because it uses OpName to canonicalize
|
||||
// IDs; i.e., run --strip-debug after this pass.
|
||||
Optimizer::PassToken CreateCanonicalizeIdsPass();
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
|
||||
|
||||
5
3rdparty/spirv-tools/source/binary.cpp
vendored
5
3rdparty/spirv-tools/source/binary.cpp
vendored
@@ -636,6 +636,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
|
||||
} break;
|
||||
|
||||
case SPV_OPERAND_TYPE_CAPABILITY:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
|
||||
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
||||
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
||||
case SPV_OPERAND_TYPE_MEMORY_MODEL:
|
||||
@@ -689,6 +690,8 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
|
||||
parsed_operand.type = SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT;
|
||||
if (type == SPV_OPERAND_TYPE_OPTIONAL_FPENCODING)
|
||||
parsed_operand.type = SPV_OPERAND_TYPE_FPENCODING;
|
||||
if (type == SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY)
|
||||
parsed_operand.type = SPV_OPERAND_TYPE_CAPABILITY;
|
||||
|
||||
const spvtools::OperandDesc* entry = nullptr;
|
||||
if (spvtools::LookupOperand(type, word, &entry)) {
|
||||
@@ -853,7 +856,7 @@ void Parser::recordNumberType(size_t inst_offset,
|
||||
info.type = SPV_NUMBER_FLOATING;
|
||||
info.bit_width = peekAt(inst_offset + 2);
|
||||
if (inst->num_words >= 4) {
|
||||
const spvtools::OperandDesc* desc;
|
||||
const spvtools::OperandDesc* desc = nullptr;
|
||||
spv_result_t status = spvtools::LookupOperand(
|
||||
SPV_OPERAND_TYPE_FPENCODING, peekAt(inst_offset + 3), &desc);
|
||||
if (status == SPV_SUCCESS) {
|
||||
|
||||
5
3rdparty/spirv-tools/source/disassemble.cpp
vendored
5
3rdparty/spirv-tools/source/disassemble.cpp
vendored
@@ -694,12 +694,12 @@ void InstructionDisassembler::EmitInstructionImpl(
|
||||
}
|
||||
|
||||
if (inst.result_id) {
|
||||
SetBlue();
|
||||
SetBlue(line);
|
||||
const std::string id_name = name_mapper_(inst.result_id);
|
||||
if (indent_)
|
||||
line << std::setw(std::max(0, indent_ - 3 - int(id_name.size())));
|
||||
line << "%" << id_name;
|
||||
ResetColor();
|
||||
ResetColor(line);
|
||||
line << " = ";
|
||||
} else {
|
||||
line << std::string(indent_, ' ');
|
||||
@@ -907,6 +907,7 @@ void InstructionDisassembler::EmitOperand(std::ostream& stream,
|
||||
stream << '"';
|
||||
} break;
|
||||
case SPV_OPERAND_TYPE_CAPABILITY:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
|
||||
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
|
||||
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
|
||||
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
|
||||
|
||||
3
3rdparty/spirv-tools/source/ext_inst.cpp
vendored
3
3rdparty/spirv-tools/source/ext_inst.cpp
vendored
@@ -55,6 +55,9 @@ spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name) {
|
||||
if (!strncmp("NonSemantic.VkspReflection.", name, 27)) {
|
||||
return SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION;
|
||||
}
|
||||
if (!strcmp("TOSA.001000.1", name)) {
|
||||
return SPV_EXT_INST_TYPE_TOSA_001000_1;
|
||||
}
|
||||
// ensure to add any known non-semantic extended instruction sets
|
||||
// above this point, and update spvExtInstIsNonSemantic()
|
||||
if (!strncmp("NonSemantic.", name, 12)) {
|
||||
|
||||
14
3rdparty/spirv-tools/source/extensions.cpp
vendored
14
3rdparty/spirv-tools/source/extensions.cpp
vendored
@@ -24,18 +24,24 @@
|
||||
namespace spvtools {
|
||||
|
||||
std::string GetExtensionString(const spv_parsed_instruction_t* inst) {
|
||||
if (inst->opcode != static_cast<uint16_t>(spv::Op::OpExtension)) {
|
||||
if ((inst->opcode != static_cast<uint16_t>(spv::Op::OpExtension)) &&
|
||||
(inst->opcode !=
|
||||
static_cast<uint16_t>(spv::Op::OpConditionalExtensionINTEL))) {
|
||||
return "ERROR_not_op_extension";
|
||||
}
|
||||
|
||||
assert(inst->num_operands == 1);
|
||||
const bool is_conditional =
|
||||
inst->opcode ==
|
||||
static_cast<uint16_t>(spv::Op::OpConditionalExtensionINTEL);
|
||||
assert(inst->num_operands == (is_conditional ? 2 : 1));
|
||||
const uint16_t op_i = is_conditional ? 1 : 0;
|
||||
|
||||
const auto& operand = inst->operands[0];
|
||||
const auto& operand = inst->operands[op_i];
|
||||
assert(operand.type == SPV_OPERAND_TYPE_LITERAL_STRING);
|
||||
assert(inst->num_words > operand.offset);
|
||||
(void)operand; /* No unused variables in release builds. */
|
||||
|
||||
return spvDecodeLiteralStringOperand(*inst, 0);
|
||||
return spvDecodeLiteralStringOperand(*inst, op_i);
|
||||
}
|
||||
|
||||
std::string ExtensionSetToString(const ExtensionSet& extensions) {
|
||||
|
||||
1011
3rdparty/spirv-tools/source/link/fnvar.cpp
vendored
Normal file
1011
3rdparty/spirv-tools/source/link/fnvar.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
244
3rdparty/spirv-tools/source/link/fnvar.h
vendored
Normal file
244
3rdparty/spirv-tools/source/link/fnvar.h
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
// Copyright 2025 The Khronos Group Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Implementation of generating multitarget modules according to the
|
||||
// *SPV_INTEL_function_variants* extension
|
||||
//
|
||||
// Multitarget module is generated by linking separate modules: a base module
|
||||
// and variant modules containing device-specific variants of the functions in
|
||||
// the base module. The behavior is controlled by Comma-Separated Values (CSV)
|
||||
// files passed to the following flags:
|
||||
// --fnvar-targets: Required columns:
|
||||
// module - module file name
|
||||
// target - device target ISA value
|
||||
// features - feature values for the target separated by '/' (FEAT_SEP)
|
||||
// --fnvar-architectures: Required columns:
|
||||
// module - module file name
|
||||
// category - device category value
|
||||
// family - device family value
|
||||
// op - opcode of the comparison instruction
|
||||
// architecture - device architecture
|
||||
// The values (except module) are decimal strings with their meaning defined in
|
||||
// the 'targets registry' as described in the extension spec. The decimal
|
||||
// strings may only encode unsigned 32-bit integers (characters 0-9), possibly
|
||||
// with leading zeros.
|
||||
//
|
||||
// In addition, --fnvar-capabilities generates OpSpecConstantCapabilitiesINTEL
|
||||
// for each module with operands corresponding to the module's capabilities.
|
||||
//
|
||||
// Each line in the targets/architectures CSV file defines one
|
||||
// OpSpecConstant<Target/Architecture>INTEL instruction, the columns correspond
|
||||
// to the operands of these instructions. One module can have multiple lines, in
|
||||
// which case they are combined into a single boolean spec constant using
|
||||
// OpSpecConstantOp and OpLogicalOr (except when category and family in the
|
||||
// architectures CSV are the same, then the lines are combined with
|
||||
// OpLogicalAnd). For example, the following architectures CSV
|
||||
//
|
||||
// module,category,family,op,architecture
|
||||
// foo.spv,1,7,174,1
|
||||
// foo.spv,1,7,178,3
|
||||
// foo.spv,1,8,170,1
|
||||
//
|
||||
// is combined as follows:
|
||||
//
|
||||
// %53 = OpSpecConstantArchitectureINTEL %bool 1 7 174 1
|
||||
// %54 = OpSpecConstantArchitectureINTEL %bool 1 7 178 3
|
||||
// %55 = OpSpecConstantArchitectureINTEL %bool 1 8 170 1
|
||||
// %56 = OpSpecConstantOp %bool LogicalAnd %53 %54
|
||||
// %foo_spv = OpSpecConstantOp %bool LogicalOr %55 %56
|
||||
//
|
||||
// The %foo_spv is annotated with OpName "foo.spv" (the module's name) which
|
||||
// serves as an identifier to find the constant later. We cannot use IDs for it
|
||||
// because the IDs get shifted during linking.
|
||||
//
|
||||
// The first module passed to `spirv-link` is considered the 'base' module. For
|
||||
// example, if base module defines functions 'foo' and 'bar' and the other
|
||||
// modules define only 'foo', only the 'foo' is treated as a function variant
|
||||
// guarded by spec constants. The 'bar' function will be untouched and therefore
|
||||
// present for all variants. The function variants are matched by name, and
|
||||
// therefore they must either have an entry point, or an Export linkage
|
||||
// attribute.
|
||||
|
||||
#ifndef FNVAR_H
|
||||
#define FNVAR_H
|
||||
|
||||
#include <cstdint>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "spirv-tools/linker.hpp"
|
||||
|
||||
namespace spvtools {
|
||||
|
||||
using opt::IRContext;
|
||||
using opt::Module;
|
||||
|
||||
// Map of instruction hash -> which variants are using the instruction (denoted
|
||||
// by the index to the variants vector)
|
||||
using FnVarUsage = std::unordered_map<size_t, std::vector<size_t>>;
|
||||
|
||||
// Map of base function call ID -> variant functions corresponding to the
|
||||
// called function (along with the variant name)
|
||||
using BaseFnCalls =
|
||||
std::map<uint32_t,
|
||||
std::vector<std::pair<std::string, const opt::Function*>>>;
|
||||
|
||||
// Name string of the SPIR-V extension this header implements.
constexpr char FNVAR_EXT_NAME[] = "SPV_INTEL_function_variants";
|
||||
// NOTE(review): presumably the version of the 'targets registry' that the CSV
// values refer to -- confirm against the extension spec.
constexpr uint32_t FNVAR_REGISTRY_VERSION = 0;
|
||||
// Separator between feature values in the `features` column of the targets CSV.
constexpr char FEAT_SEP = '/';
|
||||
|
||||
// One architecture definition: the four values of a single line of the
// architectures CSV (category, family, comparison opcode, architecture).
// Members default to zero so a default-constructed object never holds
// indeterminate values.
struct FnVarArchDef {
  uint32_t category = 0;
  uint32_t family = 0;
  uint32_t op = 0;
  uint32_t architecture = 0;
};

// One target definition: the target value and the feature values of a single
// line of the targets CSV. |target| defaults to zero and |features| to empty.
struct FnVarTargetDef {
  uint32_t target = 0;
  std::vector<uint32_t> features;
};
|
||||
|
||||
// Definition of a variant
|
||||
//
|
||||
// Stores architecture and target definitions inferred from lines in the CSV
|
||||
// files for a single module (as well as a pointer to the Module).
|
||||
class VariantDef {
|
||||
public:
|
||||
VariantDef(bool isbase, std::string nm, Module* mod)
|
||||
: is_base(isbase), name(nm), module(mod) {}
|
||||
|
||||
bool IsBase() const { return this->is_base; }
|
||||
std::string GetName() const { return this->name; }
|
||||
Module* GetModule() const { return this->module; }
|
||||
|
||||
void AddArchDef(uint32_t category, uint32_t family, uint32_t op,
|
||||
uint32_t architecture) {
|
||||
FnVarArchDef arch_def;
|
||||
arch_def.category = category;
|
||||
arch_def.family = family;
|
||||
arch_def.op = op;
|
||||
arch_def.architecture = architecture;
|
||||
this->arch_defs.push_back(arch_def);
|
||||
}
|
||||
const std::vector<FnVarArchDef>& GetArchDefs() const {
|
||||
return this->arch_defs;
|
||||
}
|
||||
|
||||
void AddTgtDef(uint32_t target, std::vector<uint32_t> features) {
|
||||
FnVarTargetDef tgt_def;
|
||||
tgt_def.target = target;
|
||||
tgt_def.features = features;
|
||||
this->tgt_defs.push_back(tgt_def);
|
||||
}
|
||||
const std::vector<FnVarTargetDef>& GetTgtDefs() const {
|
||||
return this->tgt_defs;
|
||||
}
|
||||
|
||||
void InferCapabilities() {
|
||||
for (const auto& cap_inst : module->capabilities()) {
|
||||
capabilities.insert(spv::Capability(cap_inst.GetOperand(0).words[0]));
|
||||
}
|
||||
}
|
||||
const std::set<spv::Capability>& GetCapabilities() const {
|
||||
return this->capabilities;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_base;
|
||||
std::string name;
|
||||
Module* module;
|
||||
std::vector<FnVarTargetDef> tgt_defs;
|
||||
std::vector<FnVarArchDef> arch_defs;
|
||||
std::set<spv::Capability> capabilities;
|
||||
};
|
||||
|
||||
// Collection of VariantDef instances
|
||||
//
|
||||
// Apart from being a wrapper around a vector of VariantDef instances, it
|
||||
// defines the main API for generating SPV_INTEL_function_variants instructions
|
||||
// based on the CSV files.
|
||||
class VariantDefs {
|
||||
public:
|
||||
// Returns last error message.
|
||||
std::string GetErr() { return err_.str(); }
|
||||
|
||||
// Processes CSV files passed to the CLI and populate _variants.
|
||||
//
|
||||
// Returns true on success, false on error.
|
||||
bool ProcessFnVar(const LinkerOptions& options,
|
||||
const std::vector<Module*>& modules);
|
||||
|
||||
// Analyses each variant def module and generates those instructions that are
|
||||
// module-specific, ie., not requiring knowledge from other modules.
|
||||
//
|
||||
// Returns true on success, false on error.
|
||||
bool ProcessVariantDefs();
|
||||
|
||||
// Generates basic instructions required for this extension to work.
|
||||
void GenerateHeader(IRContext* linked_context);
|
||||
|
||||
// Generates instructions from this extension that result from combining
|
||||
// several variant def modules.
|
||||
void CombineVariantInstructions(IRContext* linked_context);
|
||||
|
||||
private:
|
||||
// Adds a boolean type to every module if there is none.
|
||||
//
|
||||
// These are necessary for spec constants.
|
||||
void EnsureBoolType();
|
||||
|
||||
// Collects which combinable instructions are defined in which modules
|
||||
void CollectVarInsts();
|
||||
|
||||
// Generates OpSpecConstant<Target/Architecture/Capabilities>INTEL and
|
||||
// combines them as necessary. Also converts entry points to conditional ones
|
||||
// and decorates module-specific instructions with ConditionalINTEL.
|
||||
//
|
||||
// Returns true on success, false on error.
|
||||
bool GenerateFnVarConstants();
|
||||
|
||||
// Determines which functions in the base module are called by which function
|
||||
// variants.
|
||||
void CollectBaseFnCalls();
|
||||
|
||||
// Combines OpFunctionCall instructions collected with CollectBaseFnCalls()
|
||||
// using conditional copy.
|
||||
void CombineBaseFnCalls(IRContext* linked_context);
|
||||
|
||||
// Decorates instructions shared between modules with ConditionalINTEL or
|
||||
// generates conditional capabilities and extensions, depending on which
|
||||
// variants are used by each.
|
||||
void CombineInstructions(IRContext* linked_context);
|
||||
|
||||
// Accumulates all errors encountered during processing.
|
||||
std::stringstream err_;
|
||||
|
||||
// Collection of VariantDef instances
|
||||
std::vector<VariantDef> variant_defs_;
|
||||
|
||||
// Used for combining OpFunctionCall instructions
|
||||
BaseFnCalls base_fn_calls_;
|
||||
|
||||
// Used for determining which function variant uses which (applicable)
|
||||
// instruction
|
||||
FnVarUsage fnvar_usage_;
|
||||
};
|
||||
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // FNVAR_H
|
||||
42
3rdparty/spirv-tools/source/link/linker.cpp
vendored
42
3rdparty/spirv-tools/source/link/linker.cpp
vendored
@@ -15,9 +15,10 @@
|
||||
#include "spirv-tools/linker.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
@@ -26,18 +27,17 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fnvar.h"
|
||||
#include "source/diagnostic.h"
|
||||
#include "source/opt/build_module.h"
|
||||
#include "source/opt/compact_ids_pass.h"
|
||||
#include "source/opt/decoration_manager.h"
|
||||
#include "source/opt/ir_builder.h"
|
||||
#include "source/opt/ir_loader.h"
|
||||
#include "source/opt/pass_manager.h"
|
||||
#include "source/opt/remove_duplicates_pass.h"
|
||||
#include "source/opt/remove_unused_interface_variables_pass.h"
|
||||
#include "source/opt/type_manager.h"
|
||||
#include "source/spirv_constant.h"
|
||||
#include "source/spirv_target_env.h"
|
||||
#include "source/table2.h"
|
||||
#include "source/util/make_unique.h"
|
||||
#include "source/util/string_utils.h"
|
||||
@@ -328,7 +328,10 @@ spv_result_t MergeModules(const MessageConsumer& consumer,
|
||||
for (const auto& module : input_modules)
|
||||
for (const auto& inst : module->entry_points()) {
|
||||
const uint32_t model = inst.GetSingleWordInOperand(0);
|
||||
const std::string name = inst.GetInOperand(2).AsString();
|
||||
const std::string name =
|
||||
inst.opcode() == spv::Op::OpConditionalEntryPointINTEL
|
||||
? inst.GetOperand(3).AsString()
|
||||
: inst.GetOperand(2).AsString();
|
||||
const auto i = std::find_if(
|
||||
entry_points.begin(), entry_points.end(),
|
||||
[model, name](const std::pair<uint32_t, std::string>& v) {
|
||||
@@ -728,8 +731,7 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
|
||||
if (max_id_bound >= SPV_LIMIT_RESULT_ID_BOUND)
|
||||
DiagnosticStream({0u, 0u, 4u}, consumer, "", SPV_WARNING)
|
||||
<< "The minimum limit of IDs, " << (SPV_LIMIT_RESULT_ID_BOUND - 1)
|
||||
<< ", was exceeded:"
|
||||
<< " " << max_id_bound << " is the current ID bound.\n"
|
||||
<< ", was exceeded: " << max_id_bound << " is the current ID bound.\n"
|
||||
<< "The resulting module might not be supported by all "
|
||||
"implementations.";
|
||||
|
||||
@@ -740,8 +742,8 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
|
||||
if (num_global_values >= SPV_LIMIT_GLOBAL_VARIABLES_MAX)
|
||||
DiagnosticStream(position, consumer, "", SPV_WARNING)
|
||||
<< "The minimum limit of global values, "
|
||||
<< (SPV_LIMIT_GLOBAL_VARIABLES_MAX - 1) << ", was exceeded;"
|
||||
<< " " << num_global_values << " global values were found.\n"
|
||||
<< (SPV_LIMIT_GLOBAL_VARIABLES_MAX - 1) << ", was exceeded; "
|
||||
<< num_global_values << " global values were found.\n"
|
||||
<< "The resulting module might not be supported by all "
|
||||
"implementations.";
|
||||
|
||||
@@ -853,6 +855,22 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
|
||||
ir_contexts.push_back(std::move(ir_context));
|
||||
}
|
||||
|
||||
const bool make_multitarget = !options.GetFnVarArchitecturesCsv().empty() ||
|
||||
!options.GetFnVarTargetsCsv().empty();
|
||||
|
||||
VariantDefs variant_defs;
|
||||
|
||||
if (make_multitarget) {
|
||||
if (!variant_defs.ProcessFnVar(options, modules)) {
|
||||
return DiagnosticStream(position, consumer, "", SPV_ERROR_FNVAR)
|
||||
<< variant_defs.GetErr();
|
||||
}
|
||||
if (!variant_defs.ProcessVariantDefs()) {
|
||||
return DiagnosticStream(position, consumer, "", SPV_ERROR_FNVAR)
|
||||
<< variant_defs.GetErr();
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
|
||||
// range from the other binaries, and compute the new ID bound.
|
||||
uint32_t max_id_bound = 0u;
|
||||
@@ -866,6 +884,10 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
|
||||
IRContext linked_context(c_context->target_env, consumer);
|
||||
linked_context.module()->SetHeader(header);
|
||||
|
||||
if (make_multitarget) {
|
||||
variant_defs.GenerateHeader(&linked_context);
|
||||
}
|
||||
|
||||
// Phase 3: Merge all the binaries into a single one.
|
||||
res = MergeModules(consumer, modules, &linked_context);
|
||||
if (res != SPV_SUCCESS) return res;
|
||||
@@ -882,6 +904,10 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
|
||||
opt::Pass::Status pass_res = manager.Run(&linked_context);
|
||||
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
|
||||
|
||||
if (make_multitarget) {
|
||||
variant_defs.CombineVariantInstructions(&linked_context);
|
||||
}
|
||||
|
||||
// Phase 5: Find the import/export pairs
|
||||
LinkageTable linkings_to_do;
|
||||
res = GetImportExportPairs(consumer, linked_context,
|
||||
|
||||
15
3rdparty/spirv-tools/source/mimalloc.cpp
vendored
Normal file
15
3rdparty/spirv-tools/source/mimalloc.cpp
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (c) 2025 The Khronos Group Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mimalloc-new-delete.h"
|
||||
10
3rdparty/spirv-tools/source/opcode.cpp
vendored
10
3rdparty/spirv-tools/source/opcode.cpp
vendored
@@ -120,6 +120,9 @@ int32_t spvOpcodeIsSpecConstant(const spv::Op opcode) {
|
||||
case spv::Op::OpSpecConstantComposite:
|
||||
case spv::Op::OpSpecConstantCompositeReplicateEXT:
|
||||
case spv::Op::OpSpecConstantOp:
|
||||
case spv::Op::OpSpecConstantArchitectureINTEL:
|
||||
case spv::Op::OpSpecConstantTargetINTEL:
|
||||
case spv::Op::OpSpecConstantCapabilitiesINTEL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -144,6 +147,12 @@ int32_t spvOpcodeIsConstant(const spv::Op opcode) {
|
||||
case spv::Op::OpSpecConstantCompositeReplicateEXT:
|
||||
case spv::Op::OpSpecConstantOp:
|
||||
case spv::Op::OpSpecConstantStringAMDX:
|
||||
case spv::Op::OpGraphConstantARM:
|
||||
case spv::Op::OpAsmTargetINTEL:
|
||||
case spv::Op::OpAsmINTEL:
|
||||
case spv::Op::OpSpecConstantArchitectureINTEL:
|
||||
case spv::Op::OpSpecConstantTargetINTEL:
|
||||
case spv::Op::OpSpecConstantCapabilitiesINTEL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -264,6 +273,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
|
||||
case spv::Op::OpTypeTensorViewNV:
|
||||
case spv::Op::OpTypeTensorARM:
|
||||
case spv::Op::OpTypeTaskSequenceINTEL:
|
||||
case spv::Op::OpTypeGraphARM:
|
||||
return true;
|
||||
default:
|
||||
// In particular, OpTypeForwardPointer does not generate a type,
|
||||
|
||||
13
3rdparty/spirv-tools/source/operand.cpp
vendored
13
3rdparty/spirv-tools/source/operand.cpp
vendored
@@ -111,6 +111,7 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
|
||||
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
|
||||
return "kernel profiling info";
|
||||
case SPV_OPERAND_TYPE_CAPABILITY:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
|
||||
return "capability";
|
||||
case SPV_OPERAND_TYPE_RAY_FLAGS:
|
||||
return "ray flags";
|
||||
@@ -394,6 +395,7 @@ bool spvOperandIsOptional(spv_operand_type_t type) {
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_RAW_ACCESS_CHAIN_OPERANDS:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_FPENCODING:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_TENSOR_OPERANDS:
|
||||
case SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
@@ -408,6 +410,7 @@ bool spvOperandIsVariable(spv_operand_type_t type) {
|
||||
case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER:
|
||||
case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID:
|
||||
case SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER:
|
||||
case SPV_OPERAND_TYPE_VARIABLE_CAPABILITY:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
@@ -439,6 +442,10 @@ bool spvExpandOperandSequenceOnce(spv_operand_type_t type,
|
||||
pattern->push_back(SPV_OPERAND_TYPE_LITERAL_INTEGER);
|
||||
pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_ID);
|
||||
return true;
|
||||
case SPV_OPERAND_TYPE_VARIABLE_CAPABILITY:
|
||||
pattern->push_back(type);
|
||||
pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_CAPABILITY);
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -521,6 +528,9 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
|
||||
case spv::Op::OpMemberDecorateStringGOOGLE:
|
||||
case spv::Op::OpBranch:
|
||||
case spv::Op::OpLoopMerge:
|
||||
case spv::Op::OpConditionalEntryPointINTEL:
|
||||
case spv::Op::OpConditionalCapabilityINTEL:
|
||||
case spv::Op::OpConditionalExtensionINTEL:
|
||||
out = [](unsigned) { return true; };
|
||||
break;
|
||||
case spv::Op::OpGroupDecorate:
|
||||
@@ -571,6 +581,9 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
|
||||
// approximate, due to variable operands
|
||||
out = [](unsigned index) { return index > 6; };
|
||||
break;
|
||||
case spv::Op::OpGraphEntryPointARM:
|
||||
out = [](unsigned index) { return index == 0; };
|
||||
break;
|
||||
default:
|
||||
out = [](unsigned) { return false; };
|
||||
break;
|
||||
|
||||
@@ -44,6 +44,9 @@ constexpr uint32_t kExtInstSetInIdx = 0;
|
||||
constexpr uint32_t kExtInstOpInIdx = 1;
|
||||
constexpr uint32_t kInterpolantInIdx = 2;
|
||||
constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0;
|
||||
constexpr uint32_t kDebugValueLocalVariable = 2;
|
||||
constexpr uint32_t kDebugValueValue = 3;
|
||||
constexpr uint32_t kDebugValueExpression = 4;
|
||||
|
||||
// Sorting functor to present annotation instructions in an easy-to-process
|
||||
// order. The functor orders by opcode first and falls back on unique id
|
||||
@@ -277,9 +280,53 @@ bool AggressiveDCEPass::AggressiveDCE(Function* func) {
|
||||
live_local_vars_.clear();
|
||||
InitializeWorkList(func, structured_order);
|
||||
ProcessWorkList(func);
|
||||
ProcessDebugInformation(structured_order);
|
||||
ProcessWorkList(func);
|
||||
return KillDeadInstructions(func, structured_order);
|
||||
}
|
||||
|
||||
// Visits every instruction of every block in |structured_order| and handles
// the two NonSemantic.Shader.DebugInfo.100 instructions that need special
// treatment:
//  - DebugDeclare: always added to the worklist.
//  - DebugValue: if live_insts_.Set() reports false for the instruction that
//    defines its Value operand (presumably "was not already marked live" --
//    TODO confirm the return convention of Set()), that defining instruction
//    is rewritten in place into an operand-less OpUndef, and the DebugValue,
//    its local variable and its expression are added to the worklist.
// NOTE(review): the results of GetDef() are dereferenced without a null
// check; confirm every referenced id is guaranteed to have a definition here.
void AggressiveDCEPass::ProcessDebugInformation(
|
||||
std::list<BasicBlock*>& structured_order) {
|
||||
for (auto bi = structured_order.begin(); bi != structured_order.end(); bi++) {
|
||||
(*bi)->ForEachInst([this](Instruction* inst) {
|
||||
// DebugDeclare is not dead. It must be converted to DebugValue in a
|
||||
// later pass
|
||||
if (inst->IsNonSemanticInstruction() &&
|
||||
inst->GetShader100DebugOpcode() ==
|
||||
NonSemanticShaderDebugInfo100DebugDeclare) {
|
||||
AddToWorklist(inst);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the Value of a DebugValue is killed, set Value operand to Undef
|
||||
if (inst->IsNonSemanticInstruction() &&
|
||||
inst->GetShader100DebugOpcode() ==
|
||||
NonSemanticShaderDebugInfo100DebugValue) {
|
||||
uint32_t id = inst->GetSingleWordInOperand(kDebugValueValue);
|
||||
auto def = get_def_use_mgr()->GetDef(id);
|
||||
// Note that Set() is called for its side effect as well: after this test
|
||||
// the defining instruction's bit is set in live_insts_ on either path.
if (!live_insts_.Set(def->unique_id())) {
|
||||
AddToWorklist(inst);
|
||||
context()->get_def_use_mgr()->UpdateDefUse(inst);
|
||||
// The defining instruction is kept (pushed directly onto worklist_) but is
|
||||
// turned into an OpUndef with no in-operands.
worklist_.push(def);
|
||||
def->SetOpcode(spv::Op::OpUndef);
|
||||
def->SetInOperands({});
|
||||
// Keep the local variable referenced by the DebugValue, plus its operands.
id = inst->GetSingleWordInOperand(kDebugValueLocalVariable);
|
||||
auto localVar = get_def_use_mgr()->GetDef(id);
|
||||
AddToWorklist(localVar);
|
||||
context()->get_def_use_mgr()->UpdateDefUse(localVar);
|
||||
AddOperandsToWorkList(localVar);
|
||||
// Refresh def-use information for the rewritten (now OpUndef) instruction.
context()->get_def_use_mgr()->UpdateDefUse(def);
|
||||
// Keep the expression referenced by the DebugValue.
id = inst->GetSingleWordInOperand(kDebugValueExpression);
|
||||
auto expression = get_def_use_mgr()->GetDef(id);
|
||||
AddToWorklist(expression);
|
||||
context()->get_def_use_mgr()->UpdateDefUse(expression);
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
bool AggressiveDCEPass::KillDeadInstructions(
|
||||
const Function* func, std::list<BasicBlock*>& structured_order) {
|
||||
bool modified = false;
|
||||
@@ -916,8 +963,17 @@ bool AggressiveDCEPass::ProcessGlobalValues() {
|
||||
}
|
||||
// Save debug build identifier even if no other instructions refer to it.
|
||||
if (dbg.GetShader100DebugOpcode() ==
|
||||
NonSemanticShaderDebugInfo100DebugBuildIdentifier)
|
||||
NonSemanticShaderDebugInfo100DebugBuildIdentifier) {
|
||||
// The debug build identifier refers to other instructions that
|
||||
// can potentially be removed, they also need to be kept alive.
|
||||
dbg.ForEachInId([this](const uint32_t* id) {
|
||||
Instruction* ref_inst = get_def_use_mgr()->GetDef(*id);
|
||||
if (ref_inst) {
|
||||
live_insts_.Set(ref_inst->unique_id());
|
||||
}
|
||||
});
|
||||
continue;
|
||||
}
|
||||
to_kill_.push_back(&dbg);
|
||||
modified = true;
|
||||
}
|
||||
|
||||
@@ -150,6 +150,12 @@ class AggressiveDCEPass : public MemPass {
|
||||
// will be empty at the end.
|
||||
void ProcessWorkList(Function* func);
|
||||
|
||||
// Process each DebugDeclare and DebugValue in |func| that has not been
|
||||
// marked as live in the work list. DebugDeclare's are marked live now, and
|
||||
// DebugValue Value operands are set to OpUndef. The work list will be empty
|
||||
// at the end.
|
||||
void ProcessDebugInformation(std::list<BasicBlock*>& structured_order);
|
||||
|
||||
// Kills any instructions in |func| that have not been marked as live.
|
||||
bool KillDeadInstructions(const Function* func,
|
||||
std::list<BasicBlock*>& structured_order);
|
||||
|
||||
516
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.cpp
vendored
Normal file
516
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.cpp
vendored
Normal file
@@ -0,0 +1,516 @@
|
||||
// Copyright (c) 2025 LunarG Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/opt/canonicalize_ids_pass.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
// Entry point of the pass: builds the old-ID -> new-ID table in stages and
// then rewrites the module with it. Reports SuccessWithChange only when at
// least one ID operand was actually rewritten.
Pass::Status CanonicalizeIdsPass::Process() {
  // Every slot starts out as unused_; ScanIds() flips the IDs that appear in
  // the module to unmapped_ so the later stages know what is left to assign.
  new_id_.resize(GetBound(), unused_);
  ScanIds();

  // Each stage only assigns IDs that are still unmapped_, so the order below
  // determines precedence: types/constants, then names, then function bodies,
  // and finally everything that is left over.
  CanonicalizeTypeAndConst();
  CanonicalizeNames();
  CanonicalizeFunctions();
  CanonicalizeRemainders();

  // Rewrite the module; nothing else to do if no operand changed.
  if (!ApplyMap()) {
    return Status::SuccessWithoutChange;
  }

  // IDs moved, so the bound stored in the header has to be recomputed.
  UpdateBound();
  return Status::SuccessWithChange;
}
|
||||
|
||||
void CanonicalizeIdsPass::ScanIds() {
|
||||
get_module()->ForEachInst(
|
||||
[this](Instruction* inst) {
|
||||
// Look for types and constants.
|
||||
if (spvOpcodeGeneratesType(inst->opcode()) ||
|
||||
spvOpcodeIsConstant(inst->opcode())) {
|
||||
type_and_const_ids_.push_back(inst->result_id());
|
||||
SetNewId(inst->result_id(), unmapped_);
|
||||
}
|
||||
// Look for names.
|
||||
else if (inst->opcode() == spv::Op::OpName) {
|
||||
// store name string in map so that we can compute the hash later
|
||||
auto const name = inst->GetOperand(1).AsString();
|
||||
auto const target = inst->GetSingleWordInOperand(0);
|
||||
name_ids_[name] = target;
|
||||
SetNewId(target, unmapped_);
|
||||
}
|
||||
// Look for function IDs.
|
||||
else if (inst->opcode() == spv::Op::OpFunction) {
|
||||
auto const res_id = inst->result_id();
|
||||
function_ids_.push_back(res_id);
|
||||
SetNewId(res_id, unmapped_);
|
||||
}
|
||||
// Look for remaining result IDs.
|
||||
else if (inst->HasResultId()) {
|
||||
auto const res_id = inst->result_id();
|
||||
SetNewId(res_id, unmapped_);
|
||||
}
|
||||
},
|
||||
true);
|
||||
}
|
||||
|
||||
void CanonicalizeIdsPass::CanonicalizeTypeAndConst() {
|
||||
// Remap type IDs.
|
||||
static constexpr std::uint32_t soft_type_id_limit = 3011; // small prime.
|
||||
static constexpr std::uint32_t first_mapped_id = 8; // offset into ID space
|
||||
for (auto const id : type_and_const_ids_) {
|
||||
if (!IsOldIdUnmapped(id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute the hash value.
|
||||
auto const hash_value = HashTypeAndConst(id);
|
||||
if (hash_value != unmapped_) {
|
||||
SetNewId(id, hash_value % soft_type_id_limit + first_mapped_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hash types to canonical values. This can return ID collisions (it's a bit
|
||||
// inevitable): it's up to the caller to handle that gracefully.
|
||||
spv::Id CanonicalizeIdsPass::HashTypeAndConst(spv::Id const id) const {
|
||||
spv::Id value = 0;
|
||||
|
||||
auto const inst = get_def_use_mgr()->GetDef(id);
|
||||
auto const op_code = inst->opcode();
|
||||
switch (op_code) {
|
||||
case spv::Op::OpTypeVoid:
|
||||
value = 0;
|
||||
break;
|
||||
case spv::Op::OpTypeBool:
|
||||
value = 1;
|
||||
break;
|
||||
case spv::Op::OpTypeInt: {
|
||||
auto const signedness = inst->GetSingleWordOperand(2);
|
||||
value = 3 + signedness;
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeFloat:
|
||||
value = 5;
|
||||
break;
|
||||
case spv::Op::OpTypeVector: {
|
||||
auto const component_type = inst->GetSingleWordOperand(1);
|
||||
auto const component_count = inst->GetSingleWordOperand(2);
|
||||
value = 6 + HashTypeAndConst(component_type) * (component_count - 1);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeMatrix: {
|
||||
auto const column_type = inst->GetSingleWordOperand(1);
|
||||
auto const column_count = inst->GetSingleWordOperand(2);
|
||||
value = 30 + HashTypeAndConst(column_type) * (column_count - 1);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeImage: {
|
||||
// TODO: Why isn't the format used to compute the hash value?
|
||||
auto const sampled_type = inst->GetSingleWordOperand(1);
|
||||
auto const dim = inst->GetSingleWordOperand(2);
|
||||
auto const depth = inst->GetSingleWordOperand(3);
|
||||
auto const arrayed = inst->GetSingleWordOperand(4);
|
||||
auto const ms = inst->GetSingleWordOperand(5);
|
||||
auto const sampled = inst->GetSingleWordOperand(6);
|
||||
value = 120 + HashTypeAndConst(sampled_type) + dim + depth * 8 * 16 +
|
||||
arrayed * 4 * 16 + ms * 2 * 16 + sampled * 1 * 16;
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeSampler:
|
||||
value = 500;
|
||||
break;
|
||||
case spv::Op::OpTypeSampledImage:
|
||||
value = 502;
|
||||
break;
|
||||
case spv::Op::OpTypeArray: {
|
||||
auto const element_type = inst->GetSingleWordOperand(1);
|
||||
auto const length = inst->GetSingleWordOperand(2);
|
||||
value = 501 + HashTypeAndConst(element_type) * length;
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeRuntimeArray: {
|
||||
auto const element_type = inst->GetSingleWordOperand(1);
|
||||
value = 5000 + HashTypeAndConst(element_type);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeStruct:
|
||||
value = 10000;
|
||||
for (uint32_t w = 1; w < inst->NumOperandWords(); ++w) {
|
||||
value += (w + 1) * HashTypeAndConst(inst->GetSingleWordOperand(w));
|
||||
}
|
||||
break;
|
||||
case spv::Op::OpTypeOpaque: {
|
||||
// TODO: Name is a literal that may have more than one word.
|
||||
auto const name = inst->GetSingleWordOperand(1);
|
||||
value = 6000 + name;
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypePointer: {
|
||||
auto const type = inst->GetSingleWordOperand(2);
|
||||
value = 100000 + HashTypeAndConst(type);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeFunction:
|
||||
value = 200000;
|
||||
for (uint32_t w = 1; w < inst->NumOperandWords(); ++w) {
|
||||
value += (w + 1) * HashTypeAndConst(inst->GetSingleWordOperand(w));
|
||||
}
|
||||
break;
|
||||
case spv::Op::OpTypeEvent:
|
||||
value = 300000;
|
||||
break;
|
||||
case spv::Op::OpTypeDeviceEvent:
|
||||
value = 300001;
|
||||
break;
|
||||
case spv::Op::OpTypeReserveId:
|
||||
value = 300002;
|
||||
break;
|
||||
case spv::Op::OpTypeQueue:
|
||||
value = 300003;
|
||||
break;
|
||||
case spv::Op::OpTypePipe:
|
||||
value = 300004;
|
||||
break;
|
||||
case spv::Op::OpTypePipeStorage:
|
||||
value = 300005;
|
||||
break;
|
||||
case spv::Op::OpTypeNamedBarrier:
|
||||
value = 300006;
|
||||
break;
|
||||
case spv::Op::OpConstantTrue:
|
||||
value = 300007;
|
||||
break;
|
||||
case spv::Op::OpConstantFalse:
|
||||
value = 300008;
|
||||
break;
|
||||
case spv::Op::OpTypeRayQueryKHR:
|
||||
value = 300009;
|
||||
break;
|
||||
case spv::Op::OpTypeAccelerationStructureKHR:
|
||||
value = 300010;
|
||||
break;
|
||||
// Don't map the following types.
|
||||
// TODO: These types were not remapped in the glslang version of the
|
||||
// remapper. Support should be added as necessary.
|
||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||
case spv::Op::OpTypeCooperativeVectorNV:
|
||||
case spv::Op::OpTypeHitObjectNV:
|
||||
case spv::Op::OpTypeUntypedPointerKHR:
|
||||
case spv::Op::OpTypeNodePayloadArrayAMDX:
|
||||
case spv::Op::OpTypeTensorLayoutNV:
|
||||
case spv::Op::OpTypeTensorViewNV:
|
||||
case spv::Op::OpTypeTensorARM:
|
||||
case spv::Op::OpTypeTaskSequenceINTEL:
|
||||
value = unmapped_;
|
||||
break;
|
||||
case spv::Op::OpConstant: {
|
||||
auto const result_type = inst->GetSingleWordOperand(0);
|
||||
value = 400011 + HashTypeAndConst(result_type);
|
||||
auto const literal = inst->GetOperand(2);
|
||||
for (uint32_t w = 0; w < literal.words.size(); ++w) {
|
||||
value += (w + 3) * literal.words[w];
|
||||
}
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpConstantComposite: {
|
||||
auto const result_type = inst->GetSingleWordOperand(0);
|
||||
value = 300011 + HashTypeAndConst(result_type);
|
||||
for (uint32_t w = 2; w < inst->NumOperandWords(); ++w) {
|
||||
value += (w + 1) * HashTypeAndConst(inst->GetSingleWordOperand(w));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpConstantNull: {
|
||||
auto const result_type = inst->GetSingleWordOperand(0);
|
||||
value = 500009 + HashTypeAndConst(result_type);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpConstantSampler: {
|
||||
auto const result_type = inst->GetSingleWordOperand(0);
|
||||
value = 600011 + HashTypeAndConst(result_type);
|
||||
for (uint32_t w = 2; w < inst->NumOperandWords(); ++w) {
|
||||
value += (w + 1) * inst->GetSingleWordOperand(w);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// Don't map the following constants.
|
||||
// TODO: These constants were not remapped in the glslang version of the
|
||||
// remapper. Support should be added as necessary.
|
||||
case spv::Op::OpConstantCompositeReplicateEXT:
|
||||
case spv::Op::OpConstantFunctionPointerINTEL:
|
||||
case spv::Op::OpConstantStringAMDX:
|
||||
case spv::Op::OpSpecConstantTrue:
|
||||
case spv::Op::OpSpecConstantFalse:
|
||||
case spv::Op::OpSpecConstant:
|
||||
case spv::Op::OpSpecConstantComposite:
|
||||
case spv::Op::OpSpecConstantCompositeReplicateEXT:
|
||||
case spv::Op::OpSpecConstantOp:
|
||||
case spv::Op::OpSpecConstantStringAMDX:
|
||||
value = unmapped_;
|
||||
break;
|
||||
// TODO: Add additional types/constants as needed. See
|
||||
// spvOpcodeGeneratesType and spvOpcodeIsConstant.
|
||||
default:
|
||||
context()->consumer()(SPV_MSG_WARNING, "", {0, 0, 0},
|
||||
"unhandled opcode will not be canonicalized");
|
||||
break;
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
void CanonicalizeIdsPass::CanonicalizeNames() {
|
||||
static constexpr std::uint32_t soft_type_id_limit = 3011; // Small prime.
|
||||
static constexpr std::uint32_t first_mapped_id =
|
||||
3019; // Offset into ID space.
|
||||
|
||||
for (auto const& [name, target] : name_ids_) {
|
||||
if (!IsOldIdUnmapped(target)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
spv::Id hash_value = 1911;
|
||||
for (const char c : name) {
|
||||
hash_value = hash_value * 1009 + c;
|
||||
}
|
||||
|
||||
if (IsOldIdUnmapped(target)) {
|
||||
SetNewId(target, hash_value % soft_type_id_limit + first_mapped_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CanonicalizeIdsPass::CanonicalizeFunctions() {
|
||||
static constexpr std::uint32_t soft_type_id_limit = 19071; // Small prime.
|
||||
static constexpr std::uint32_t first_mapped_id =
|
||||
6203; // Offset into ID space.
|
||||
// Window size for context-sensitive canonicalization values
|
||||
// Empirical best size from a single data set. TODO: Would be a good tunable.
|
||||
// We essentially perform a little convolution around each instruction,
|
||||
// to capture the flavor of nearby code, to hopefully match to similar
|
||||
// code in other modules.
|
||||
static const int32_t window_size = 2;
|
||||
|
||||
for (auto const func_id : function_ids_) {
|
||||
// Store the instructions and opcode hash values in vectors so that the
|
||||
// window of instructions can be easily accessed and avoid having to
|
||||
// recompute the hash value repeatedly in overlapping windows.
|
||||
std::vector<Instruction*> insts;
|
||||
std::vector<uint32_t> opcode_hashvals;
|
||||
auto const func = context()->GetFunction(func_id);
|
||||
func->WhileEachInst([&](Instruction* inst) {
|
||||
insts.emplace_back(inst);
|
||||
opcode_hashvals.emplace_back(HashOpCode(inst));
|
||||
return true;
|
||||
});
|
||||
|
||||
// For every instruction in the function, compute the hash value using the
|
||||
// instruction and a small window of surrounding instructions.
|
||||
assert(insts.size() < (size_t)std::numeric_limits<int32_t>::max());
|
||||
for (int32_t i = 0; i < (int32_t)insts.size(); ++i) {
|
||||
auto const inst = insts[i];
|
||||
if (!inst->HasResultId()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto const old_id = inst->result_id();
|
||||
if (!IsOldIdUnmapped(old_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int32_t const lower_bound = std::max(0, i - window_size);
|
||||
int32_t const upper_bound =
|
||||
std::min((int32_t)insts.size() - 1, i + window_size);
|
||||
spv::Id hash_value = func_id * 17; // Small prime.
|
||||
// Include the hash value of the preceding instructions in the hash but
|
||||
// don't include instructions before the OpFunction.
|
||||
for (int32_t j = i - 1; j >= lower_bound; --j) {
|
||||
auto const local_inst = insts[j];
|
||||
if (local_inst->opcode() == spv::Op::OpFunction) {
|
||||
break;
|
||||
}
|
||||
|
||||
hash_value = hash_value * 30103 +
|
||||
opcode_hashvals[j]; // 30103 is a semi-arbitrary prime.
|
||||
}
|
||||
|
||||
// Include the hash value of the subsequent instructions in the hash but
|
||||
// don't include instructions past OpFunctionEnd.
|
||||
for (int32_t j = i; j <= upper_bound; ++j) {
|
||||
auto const local_inst = insts[j];
|
||||
if (local_inst->opcode() == spv::Op::OpFunctionEnd) {
|
||||
break;
|
||||
}
|
||||
|
||||
hash_value = hash_value * 30103 +
|
||||
opcode_hashvals[j]; // 30103 is a semiarbitrary prime.
|
||||
}
|
||||
|
||||
SetNewId(old_id, hash_value % soft_type_id_limit + first_mapped_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hashes a single instruction by its opcode. For OpExtInst the word at operand
// index 3 (the extended instruction literal) is added so that different
// extended instructions hash differently.
spv::Id CanonicalizeIdsPass::HashOpCode(Instruction const* const inst) const {
  auto const opcode = inst->opcode();
  std::uint32_t const ext_inst_offset =
      (opcode == spv::Op::OpExtInst) ? inst->GetSingleWordOperand(3) : 0;

  // 19 is a small prime.
  return static_cast<std::uint32_t>(opcode) * 19 + ext_inst_offset;
}
|
||||
|
||||
// Assign remaining IDs sequentially from remaining holes in the new ID space.
|
||||
void CanonicalizeIdsPass::CanonicalizeRemainders() {
|
||||
spv::Id next_id = 1;
|
||||
for (uint32_t old_id = 0; old_id < new_id_.size(); ++old_id) {
|
||||
if (IsOldIdUnmapped(old_id)) {
|
||||
next_id = SetNewId(old_id, next_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CanonicalizeIdsPass::ApplyMap() {
|
||||
bool modified = false;
|
||||
context()->module()->ForEachInst(
|
||||
[this, &modified](Instruction* inst) {
|
||||
for (auto operand = inst->begin(); operand != inst->end(); ++operand) {
|
||||
const auto type = operand->type;
|
||||
if (spvIsIdType(type)) {
|
||||
uint32_t& id = operand->words[0];
|
||||
uint32_t const new_id = GetNewId(id);
|
||||
if (new_id == unused_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
assert(new_id != unmapped_ && "new_id should not be unmapped_");
|
||||
|
||||
if (id != new_id) {
|
||||
modified = true;
|
||||
id = new_id;
|
||||
if (type == SPV_OPERAND_TYPE_RESULT_ID) {
|
||||
inst->SetResultId(new_id);
|
||||
} else if (type == SPV_OPERAND_TYPE_TYPE_ID) {
|
||||
inst->SetResultType(new_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
true);
|
||||
|
||||
return modified;
|
||||
}
|
||||
|
||||
// Returns the ID bound currently recorded for the module.
spv::Id CanonicalizeIdsPass::GetBound() const {
  auto* const mod = context()->module();
  return mod->id_bound();
}
|
||||
|
||||
// Recomputes the module's ID bound from its current contents, stores it, and
// then resets the context's feature manager.
void CanonicalizeIdsPass::UpdateBound() {
  auto const recomputed_bound = context()->module()->ComputeIdBound();
  context()->module()->SetIdBound(recomputed_bound);

  context()->ResetFeatureManager();
}
|
||||
|
||||
// Set a new ID. If the new ID is alreadly claimed, the next consecutive ID
|
||||
// will be claimed, mapped, and returned to the caller.
|
||||
spv::Id CanonicalizeIdsPass::SetNewId(spv::Id const old_id, spv::Id new_id) {
|
||||
assert(old_id < GetBound() && "don't remap an ID that is out of bounds");
|
||||
|
||||
if (old_id >= new_id_.size()) {
|
||||
new_id_.resize(old_id + 1, unused_);
|
||||
}
|
||||
|
||||
if (new_id != unmapped_ && new_id != unused_) {
|
||||
assert(!IsOldIdUnused(old_id) && "don't remap unused IDs");
|
||||
assert(IsOldIdUnmapped(old_id) && "don't remap already mapped IDs");
|
||||
|
||||
new_id = ClaimNewId(new_id);
|
||||
}
|
||||
|
||||
new_id_[old_id] = new_id;
|
||||
|
||||
return new_id;
|
||||
}
|
||||
|
||||
// Helper function for SetNewID. Claim a new ID. If the new ID is already
|
||||
// claimed, the next consecutive ID will be claimed and returned to the caller.
|
||||
spv::Id CanonicalizeIdsPass::ClaimNewId(spv::Id new_id) {
|
||||
// Return the ID if it's not taken.
|
||||
auto iter = claimed_new_ids_.find(new_id);
|
||||
if (iter != claimed_new_ids_.end()) {
|
||||
// Otherwise, search for the next unused ID using our current iterator.
|
||||
// Technically, it's a linear search across the set starting at the
|
||||
// iterator, but it's not as bad as it would appear in practice assuming the
|
||||
// hash values are well distributed.
|
||||
iter = std::adjacent_find(iter, claimed_new_ids_.end(), [](int a, int b) {
|
||||
return a + 1 != b; // Stop at the first non-consecutive pair.
|
||||
});
|
||||
if (iter != claimed_new_ids_.end()) {
|
||||
new_id =
|
||||
*iter + 1; // We need the next ID after where the search stopped.
|
||||
} else {
|
||||
new_id = *(--iter) + 1; // We reached the end so we use the next ID.
|
||||
}
|
||||
}
|
||||
|
||||
assert(!IsNewIdClaimed(new_id) &&
|
||||
"don't remap to an ID that is already claimed");
|
||||
iter = claimed_new_ids_.insert(iter, new_id);
|
||||
assert(*iter == new_id);
|
||||
|
||||
return new_id;
|
||||
}
|
||||
|
||||
// Formats |id| for debug output: the two sentinel values are spelled out by
// name, every other ID is printed as a decimal number.
std::string CanonicalizeIdsPass::IdAsString(spv::Id const id) const {
  if (id == unused_) {
    return "unused";
  }
  if (id == unmapped_) {
    return "unmapped";
  }
  return std::to_string(id);
}
|
||||
|
||||
void CanonicalizeIdsPass::PrintNewIds() const {
|
||||
for (spv::Id id = 0; id < new_id_.size(); ++id) {
|
||||
auto const message =
|
||||
"new id[" + IdAsString(id) + "]: " + IdAsString(new_id_[id]);
|
||||
context()->consumer()(SPV_MSG_INFO, "", {0, 0, 0}, message.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
115
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.h
vendored
Normal file
115
3rdparty/spirv-tools/source/opt/canonicalize_ids_pass.h
vendored
Normal file
@@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2025 LunarG Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_OPT_CANONICALIZE_IDS_PASS_H_
#define SOURCE_OPT_CANONICALIZE_IDS_PASS_H_

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <vector>

#include "source/opt/pass.h"

namespace spvtools {
namespace opt {

// The canonicalize IDs pass is an optimization to improve compression of SPIR-V
// binary files via entropy reduction. It transforms SPIR-V to SPIR-V, remapping
// IDs. The resulting modules have an increased ID range (IDs are not as tightly
// packed around zero), but will compress better when multiple modules are
// compressed together, since the compressor's dictionary can find better
// cross-module commonality. Remapping is accomplished via canonicalization.
// Thus, modules can be compressed one at a time with no loss of quality
// relative to operating on many modules at once.

// This pass should be run after most optimization passes except for
// --strip-debug, because this pass uses OpName to canonicalize IDs. In other
// words, run --strip-debug after this pass.

// This is a port of the remap utility in glslang. There are a great many magic
// numbers throughout this code. The general goal is to replace the IDs with a
// hash value such that the distribution of IDs is deterministic and minimizes
// collisions. The magic numbers in the glslang version were chosen
// semi-arbitrarily and have been preserved in this port in order to maintain
// backward compatibility.

class CanonicalizeIdsPass : public Pass {
 public:
  CanonicalizeIdsPass() = default;
  virtual ~CanonicalizeIdsPass() = default;

  // Builds the old-ID -> new-ID map and applies it to the module.
  Pass::Status Process() override;

  const char* name() const override { return "canonicalize-ids"; }

 private:
  // Special values stored in new_id_ in place of a real ID:
  //  - unmapped_: the old ID appears in the module but has no new ID yet.
  //  - unused_:   the old ID was not seen in the module.
  static constexpr spv::Id unmapped_{spv::Id(-10000)};
  static constexpr spv::Id unused_{spv::Id(-10001)};

  // Scans the module for IDs and sets them to unmapped_.
  void ScanIds();

  // Functions to compute new IDs.
  void CanonicalizeTypeAndConst();
  spv::Id HashTypeAndConst(
      spv::Id const id) const;  // Helper for CanonicalizeTypeAndConst.
  void CanonicalizeNames();
  void CanonicalizeFunctions();
  spv::Id HashOpCode(Instruction const* const inst)
      const;  // Helper for CanonicalizeFunctions.
  void CanonicalizeRemainders();

  // Applies the new IDs. Returns true if the module was modified.
  bool ApplyMap();

  // Methods to manage the bound field in header.
  spv::Id GetBound() const;  // All IDs must satisfy 0 < ID < bound.
  void UpdateBound();

  // Methods to map from old IDs to new IDs.
  // GetNewId() indexes new_id_ directly; |old_id| must be within its size.
  spv::Id GetNewId(spv::Id const old_id) const { return new_id_[old_id]; }
  spv::Id SetNewId(spv::Id const old_id, spv::Id new_id);

  // Methods to manage claimed IDs.
  spv::Id ClaimNewId(spv::Id new_id);
  bool IsNewIdClaimed(spv::Id const new_id) const {
    return claimed_new_ids_.find(new_id) != claimed_new_ids_.end();
  }

  // Queries for old IDs.
  bool IsOldIdUnmapped(spv::Id const old_id) const {
    return GetNewId(old_id) == unmapped_;
  }
  bool IsOldIdUnused(spv::Id const old_id) const {
    return GetNewId(old_id) == unused_;
  }

  // Container to map old IDs to new IDs. e.g. new_id_[old_id] = new_id
  std::vector<spv::Id> new_id_;

  // IDs from the new ID space that have been claimed (faster than searching
  // through new_id_).
  std::set<spv::Id> claimed_new_ids_;

  // Helper functions for printing IDs (useful for debugging).
  std::string IdAsString(spv::Id const id) const;
  void PrintNewIds() const;

  // Containers to track IDs we want to canonicalize.
  std::vector<spv::Id> type_and_const_ids_;
  std::map<std::string, spv::Id> name_ids_;
  std::vector<spv::Id> function_ids_;
};

}  // namespace opt
}  // namespace spvtools

#endif  // SOURCE_OPT_CANONICALIZE_IDS_PASS_H_
|
||||
7
3rdparty/spirv-tools/source/opt/ccp_pass.cpp
vendored
7
3rdparty/spirv-tools/source/opt/ccp_pass.cpp
vendored
@@ -360,6 +360,13 @@ void CCPPass::Initialize() {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark the extended instruction imports as `kVarying`. We know they
|
||||
// will not be constants, and will be used by `OpExtInst` instructions.
|
||||
// This allows those instructions to be fully processed.
|
||||
for (const auto& inst : get_module()->ext_inst_imports()) {
|
||||
values_[inst.result_id()] = kVaryingSSAId;
|
||||
}
|
||||
|
||||
original_id_bound_ = context()->module()->IdBound();
|
||||
}
|
||||
|
||||
|
||||
@@ -1395,9 +1395,12 @@ ConstantFoldingRule FoldFMix() {
|
||||
if (base_type->AsFloat()->width() == 32) {
|
||||
one = const_mgr->GetConstant(base_type,
|
||||
utils::FloatProxy<float>(1.0f).GetWords());
|
||||
} else {
|
||||
} else if (base_type->AsFloat()->width() == 64) {
|
||||
one = const_mgr->GetConstant(base_type,
|
||||
utils::FloatProxy<double>(1.0).GetWords());
|
||||
} else {
|
||||
// We won't support folding half types.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (is_vector) {
|
||||
@@ -1433,14 +1436,29 @@ const analysis::Constant* FoldMin(const analysis::Type* result_type,
|
||||
const analysis::Constant* b,
|
||||
analysis::ConstantManager*) {
|
||||
if (const analysis::Integer* int_type = result_type->AsInteger()) {
|
||||
if (int_type->width() == 32) {
|
||||
if (int_type->width() <= 32) {
|
||||
assert(
|
||||
(a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
|
||||
"Must be an integer or null constant.");
|
||||
assert(
|
||||
(b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
|
||||
"Must be an integer or null constant.");
|
||||
|
||||
if (int_type->IsSigned()) {
|
||||
int32_t va = a->GetS32();
|
||||
int32_t vb = b->GetS32();
|
||||
int32_t va = (a->AsIntConstant() != nullptr)
|
||||
? a->AsIntConstant()->GetS32BitValue()
|
||||
: 0;
|
||||
int32_t vb = (b->AsIntConstant() != nullptr)
|
||||
? b->AsIntConstant()->GetS32BitValue()
|
||||
: 0;
|
||||
return (va < vb ? a : b);
|
||||
} else {
|
||||
uint32_t va = a->GetU32();
|
||||
uint32_t vb = b->GetU32();
|
||||
uint32_t va = (a->AsIntConstant() != nullptr)
|
||||
? a->AsIntConstant()->GetU32BitValue()
|
||||
: 0;
|
||||
uint32_t vb = (b->AsIntConstant() != nullptr)
|
||||
? b->AsIntConstant()->GetU32BitValue()
|
||||
: 0;
|
||||
return (va < vb ? a : b);
|
||||
}
|
||||
} else if (int_type->width() == 64) {
|
||||
@@ -1473,14 +1491,29 @@ const analysis::Constant* FoldMax(const analysis::Type* result_type,
|
||||
const analysis::Constant* b,
|
||||
analysis::ConstantManager*) {
|
||||
if (const analysis::Integer* int_type = result_type->AsInteger()) {
|
||||
if (int_type->width() == 32) {
|
||||
if (int_type->width() <= 32) {
|
||||
assert(
|
||||
(a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
|
||||
"Must be an integer or null constant.");
|
||||
assert(
|
||||
(b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
|
||||
"Must be an integer or null constant.");
|
||||
|
||||
if (int_type->IsSigned()) {
|
||||
int32_t va = a->GetS32();
|
||||
int32_t vb = b->GetS32();
|
||||
int32_t va = (a->AsIntConstant() != nullptr)
|
||||
? a->AsIntConstant()->GetS32BitValue()
|
||||
: 0;
|
||||
int32_t vb = (b->AsIntConstant() != nullptr)
|
||||
? b->AsIntConstant()->GetS32BitValue()
|
||||
: 0;
|
||||
return (va > vb ? a : b);
|
||||
} else {
|
||||
uint32_t va = a->GetU32();
|
||||
uint32_t vb = b->GetU32();
|
||||
uint32_t va = (a->AsIntConstant() != nullptr)
|
||||
? a->AsIntConstant()->GetU32BitValue()
|
||||
: 0;
|
||||
uint32_t vb = (b->AsIntConstant() != nullptr)
|
||||
? b->AsIntConstant()->GetU32BitValue()
|
||||
: 0;
|
||||
return (va > vb ? a : b);
|
||||
}
|
||||
} else if (int_type->width() == 64) {
|
||||
|
||||
@@ -315,6 +315,7 @@ const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
|
||||
case spv::Op::OpConstant:
|
||||
case spv::Op::OpConstantComposite:
|
||||
case spv::Op::OpSpecConstantComposite:
|
||||
case spv::Op::OpSpecConstantCompositeReplicateEXT:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
|
||||
@@ -558,11 +558,11 @@ bool DebugInfoManager::IsDeclareVisibleToInstr(Instruction* dbg_declare,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool DebugInfoManager::AddDebugValueForVariable(Instruction* scope_and_line,
|
||||
bool DebugInfoManager::AddDebugValueForVariable(Instruction* line,
|
||||
uint32_t variable_id,
|
||||
uint32_t value_id,
|
||||
Instruction* insert_pos) {
|
||||
assert(scope_and_line != nullptr);
|
||||
assert(line != nullptr);
|
||||
|
||||
auto dbg_decl_itr = var_id_to_dbg_decl_.find(variable_id);
|
||||
if (dbg_decl_itr == var_id_to_dbg_decl_.end()) return false;
|
||||
@@ -577,14 +577,15 @@ bool DebugInfoManager::AddDebugValueForVariable(Instruction* scope_and_line,
|
||||
insert_before = insert_before->NextNode();
|
||||
}
|
||||
modified |= AddDebugValueForDecl(dbg_decl_or_val, value_id, insert_before,
|
||||
scope_and_line) != nullptr;
|
||||
line) != nullptr;
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
Instruction* DebugInfoManager::AddDebugValueForDecl(
|
||||
Instruction* dbg_decl, uint32_t value_id, Instruction* insert_before,
|
||||
Instruction* scope_and_line) {
|
||||
Instruction* DebugInfoManager::AddDebugValueForDecl(Instruction* dbg_decl,
|
||||
uint32_t value_id,
|
||||
Instruction* insert_before,
|
||||
Instruction* line) {
|
||||
if (dbg_decl == nullptr || !IsDebugDeclare(dbg_decl)) return nullptr;
|
||||
|
||||
std::unique_ptr<Instruction> dbg_val(dbg_decl->Clone(context()));
|
||||
@@ -593,7 +594,7 @@ Instruction* DebugInfoManager::AddDebugValueForDecl(
|
||||
dbg_val->SetOperand(kDebugDeclareOperandVariableIndex, {value_id});
|
||||
dbg_val->SetOperand(kDebugValueOperandExpressionIndex,
|
||||
{GetEmptyDebugExpression()->result_id()});
|
||||
dbg_val->UpdateDebugInfoFrom(scope_and_line);
|
||||
dbg_val->UpdateDebugInfoFrom(dbg_decl, line);
|
||||
|
||||
auto* added_dbg_val = insert_before->InsertBefore(std::move(dbg_val));
|
||||
AnalyzeDebugInst(added_dbg_val);
|
||||
|
||||
@@ -143,22 +143,21 @@ class DebugInfoManager {
|
||||
bool KillDebugDeclares(uint32_t variable_id);
|
||||
|
||||
// Generates a DebugValue instruction with value |value_id| for every local
|
||||
// variable that is in the scope of |scope_and_line| and whose memory is
|
||||
// |variable_id| and inserts it after the instruction |insert_pos|.
|
||||
// variable that is in the scope of |line| and whose memory is |variable_id|
|
||||
// and inserts it after the instruction |insert_pos|.
|
||||
// Returns whether a DebugValue is added or not.
|
||||
bool AddDebugValueForVariable(Instruction* scope_and_line,
|
||||
uint32_t variable_id, uint32_t value_id,
|
||||
Instruction* insert_pos);
|
||||
bool AddDebugValueForVariable(Instruction* line, uint32_t variable_id,
|
||||
uint32_t value_id, Instruction* insert_pos);
|
||||
|
||||
// Creates a DebugValue for DebugDeclare |dbg_decl| and inserts it before
|
||||
// |insert_before|. The new DebugValue has the same line and scope as
|
||||
// |scope_and_line|, or no scope and line information if |scope_and_line|
|
||||
// is nullptr. The new DebugValue has the same operands as DebugDeclare
|
||||
// but it uses |value_id| for the value. Returns the created DebugValue,
|
||||
// |insert_before|. The new DebugValue has the same line as |line| and the
|
||||
// same scope as |dbg_decl|. The new DebugValue has the same operands as
|
||||
// DebugDeclare but it uses |value_id| for the value. Returns the created
|
||||
// DebugValue,
|
||||
// or nullptr if it fails to create one.
|
||||
Instruction* AddDebugValueForDecl(Instruction* dbg_decl, uint32_t value_id,
|
||||
Instruction* insert_before,
|
||||
Instruction* scope_and_line);
|
||||
Instruction* line);
|
||||
|
||||
// Erases |instr| from data structures of this class.
|
||||
void ClearDebugInfo(Instruction* instr);
|
||||
|
||||
27
3rdparty/spirv-tools/source/opt/desc_sroa.cpp
vendored
27
3rdparty/spirv-tools/source/opt/desc_sroa.cpp
vendored
@@ -58,7 +58,7 @@ bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
|
||||
std::vector<Instruction*> access_chain_work_list;
|
||||
std::vector<Instruction*> load_work_list;
|
||||
std::vector<Instruction*> entry_point_work_list;
|
||||
bool failed = !get_def_use_mgr()->WhileEachUser(
|
||||
bool ok = get_def_use_mgr()->WhileEachUser(
|
||||
var->result_id(), [this, &access_chain_work_list, &load_work_list,
|
||||
&entry_point_work_list](Instruction* use) {
|
||||
if (use->opcode() == spv::Op::OpName) {
|
||||
@@ -88,7 +88,7 @@ bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
|
||||
return true;
|
||||
});
|
||||
|
||||
if (failed) {
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -128,6 +128,9 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
|
||||
|
||||
uint32_t idx = const_index->GetU32();
|
||||
uint32_t replacement_var = GetReplacementVariable(var, idx);
|
||||
if (replacement_var == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use->NumInOperands() == 2) {
|
||||
// We are not indexing into the replacement variable. We can replace the
|
||||
@@ -186,8 +189,11 @@ bool DescriptorScalarReplacement::ReplaceEntryPoint(Instruction* var,
|
||||
uint32_t num_replacement_vars =
|
||||
descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
|
||||
for (uint32_t i = 0; i < num_replacement_vars; i++) {
|
||||
new_operands.push_back(
|
||||
{SPV_OPERAND_TYPE_ID, {GetReplacementVariable(var, i)}});
|
||||
uint32_t replacement_var_id = GetReplacementVariable(var, i);
|
||||
if (replacement_var_id == 0) {
|
||||
return false;
|
||||
}
|
||||
new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var_id}});
|
||||
}
|
||||
|
||||
use->ReplaceOperands(new_operands);
|
||||
@@ -310,7 +316,10 @@ uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
|
||||
element_type_id, storage_class);
|
||||
|
||||
// Create the variable.
|
||||
uint32_t id = TakeNextId();
|
||||
uint32_t id = context()->TakeNextId();
|
||||
if (id == 0) {
|
||||
return 0;
|
||||
}
|
||||
std::unique_ptr<Instruction> variable(
|
||||
new Instruction(context(), spv::Op::OpVariable, ptr_element_type_id, id,
|
||||
std::initializer_list<Operand>{
|
||||
@@ -444,10 +453,16 @@ bool DescriptorScalarReplacement::ReplaceCompositeExtract(
|
||||
|
||||
uint32_t replacement_var =
|
||||
GetReplacementVariable(var, extract->GetSingleWordInOperand(1));
|
||||
if (replacement_var == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The result type of the OpLoad is the same as the result type of the
|
||||
// OpCompositeExtract.
|
||||
uint32_t load_id = TakeNextId();
|
||||
uint32_t load_id = context()->TakeNextId();
|
||||
if (load_id == 0) {
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<Instruction> load(
|
||||
new Instruction(context(), spv::Op::OpLoad, extract->type_id(), load_id,
|
||||
std::initializer_list<Operand>{
|
||||
|
||||
@@ -34,10 +34,13 @@ void FeatureManager::AddExtensions(Module* module) {
|
||||
}
|
||||
|
||||
void FeatureManager::AddExtension(Instruction* ext) {
|
||||
assert(ext->opcode() == spv::Op::OpExtension &&
|
||||
assert((ext->opcode() == spv::Op::OpExtension ||
|
||||
ext->opcode() == spv::Op::OpConditionalExtensionINTEL) &&
|
||||
"Expecting an extension instruction.");
|
||||
|
||||
const std::string name = ext->GetInOperand(0u).AsString();
|
||||
const uint32_t name_i =
|
||||
ext->opcode() == spv::Op::OpConditionalExtensionINTEL ? 1u : 0u;
|
||||
const std::string name = ext->GetInOperand(name_i).AsString();
|
||||
Extension extension;
|
||||
if (GetExtensionFromString(name.c_str(), &extension)) {
|
||||
extensions_.insert(extension);
|
||||
@@ -72,7 +75,10 @@ void FeatureManager::RemoveCapability(spv::Capability cap) {
|
||||
|
||||
void FeatureManager::AddCapabilities(Module* module) {
|
||||
for (Instruction& inst : module->capabilities()) {
|
||||
AddCapability(static_cast<spv::Capability>(inst.GetSingleWordInOperand(0)));
|
||||
const uint32_t i_cap =
|
||||
inst.opcode() == spv::Op::OpConditionalCapabilityINTEL ? 1 : 0;
|
||||
AddCapability(
|
||||
static_cast<spv::Capability>(inst.GetSingleWordInOperand(i_cap)));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright (c) 2016 Google Inc.
|
||||
// Copyright (c) 2025 Arm Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -31,21 +32,20 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
|
||||
// instructions, records their values in two internal maps: id_to_const_val_
|
||||
// and const_val_to_id_ so that we can use them to infer the value of Spec
|
||||
// Constants later.
|
||||
// For Spec Constants defined with OpSpecConstantComposite instructions, if
|
||||
// all of their components are Normal Constants, they will be turned into
|
||||
// Normal Constants too. For Spec Constants defined with OpSpecConstantOp
|
||||
// instructions, we check if they only depends on Normal Constants and fold
|
||||
// them when possible. The two maps for Normal Constants: id_to_const_val_
|
||||
// and const_val_to_id_ will be updated along the traversal so that the new
|
||||
// Normal Constants generated from folding can be used to fold following Spec
|
||||
// Constants.
|
||||
// This algorithm depends on the SSA property of SPIR-V when
|
||||
// defining constants. The dependent constants must be defined before the
|
||||
// dependee constants. So a dependent Spec Constant must be defined and
|
||||
// will be processed before its dependee Spec Constant. When we encounter
|
||||
// the dependee Spec Constants, all its dependent constants must have been
|
||||
// processed and all its dependent Spec Constants should have been folded if
|
||||
// possible.
|
||||
// For Spec Constants defined with OpSpecConstantComposite or
|
||||
// OpSpecConstantCompositeReplicateEXT instructions, if all of their
|
||||
// components are Normal Constants, they will be turned into Normal Constants
|
||||
// too. For Spec Constants defined with OpSpecConstantOp instructions, we
|
||||
// check if they only depend on Normal Constants and fold them when possible.
|
||||
// The two maps for Normal Constants: id_to_const_val_ and const_val_to_id_
|
||||
// will be updated along the traversal so that the new Normal Constants
|
||||
// generated from folding can be used to fold following Spec Constants. This
|
||||
// algorithm depends on the SSA property of SPIR-V when defining constants.
|
||||
// The dependent constants must be defined before the dependee constants. So a
|
||||
// dependent Spec Constant must be defined and will be processed before its
|
||||
// dependee Spec Constant. When we encounter the dependee Spec Constants, all
|
||||
// its dependent constants must have been processed and all its dependent Spec
|
||||
// Constants should have been folded if possible.
|
||||
Module::inst_iterator next_inst = context()->types_values_begin();
|
||||
for (Module::inst_iterator inst_iter = next_inst;
|
||||
// Need to re-evaluate the end iterator since we may modify the list of
|
||||
@@ -54,8 +54,9 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
|
||||
++next_inst;
|
||||
Instruction* inst = &*inst_iter;
|
||||
// Collect constant values of normal constants and process the
|
||||
// OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
|
||||
// The constant values will be stored in analysis::Constant instances.
|
||||
// OpSpecConstantOp, OpSpecConstantComposite, and
|
||||
// OpSpecConstantCompositeReplicateEXT instructions if possible. The
|
||||
// constant values will be stored in analysis::Constant instances.
|
||||
// OpConstantSampler instruction is not collected here because it cannot be
|
||||
// used in OpSpecConstant{Composite|Op} instructions.
|
||||
// TODO(qining): If the constant or its type has decoration, we may need
|
||||
@@ -70,21 +71,29 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
|
||||
case spv::Op::OpConstant:
|
||||
case spv::Op::OpConstantNull:
|
||||
case spv::Op::OpConstantComposite:
|
||||
case spv::Op::OpSpecConstantComposite: {
|
||||
case spv::Op::OpSpecConstantComposite:
|
||||
case spv::Op::OpSpecConstantCompositeReplicateEXT: {
|
||||
// A Constant instance will be created if the given instruction is a
|
||||
// Normal Constant whose value(s) are fixed. Note that for a composite
|
||||
// Spec Constant defined with OpSpecConstantComposite instruction, if
|
||||
// all of its components are Normal Constants already, the Spec
|
||||
// Constant will be turned in to a Normal Constant. In that case, a
|
||||
// Constant instance should also be created successfully and recorded
|
||||
// in the id_to_const_val_ and const_val_to_id_ mapps.
|
||||
// Spec Constant defined with OpSpecConstantComposite or
|
||||
// OpSpecConstantCompositeReplicateEXT instruction, if all of its
|
||||
// components are Normal Constants already, the Spec Constant will be
|
||||
// turned into a Normal Constant. In that case, a Constant instance
|
||||
// should also be created successfully and recorded in the
|
||||
// id_to_const_val_ and const_val_to_id_ maps.
|
||||
if (auto const_value = const_mgr->GetConstantFromInst(inst)) {
|
||||
// Need to replace the OpSpecConstantComposite instruction with a
|
||||
// corresponding OpConstantComposite instruction.
|
||||
// Need to replace the OpSpecConstantComposite or
|
||||
// OpSpecConstantCompositeReplicateEXT instruction with a
|
||||
// corresponding OpConstantComposite or
|
||||
// OpConstantCompositeReplicateEXT instruction.
|
||||
if (opcode == spv::Op::OpSpecConstantComposite) {
|
||||
inst->SetOpcode(spv::Op::OpConstantComposite);
|
||||
modified = true;
|
||||
}
|
||||
if (opcode == spv::Op::OpSpecConstantCompositeReplicateEXT) {
|
||||
inst->SetOpcode(spv::Op::OpConstantCompositeReplicateEXT);
|
||||
modified = true;
|
||||
}
|
||||
const_mgr->MapConstantToInst(const_value, inst);
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -1998,6 +1998,15 @@ FoldingRule FMixFeedingExtract() {
|
||||
bool use_x = false;
|
||||
|
||||
assert(a_const->type()->AsFloat());
|
||||
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
uint32_t width = ElementWidth(type);
|
||||
if (width != 32 && width != 64) {
|
||||
// We won't support folding half float values.
|
||||
return false;
|
||||
}
|
||||
|
||||
double element_value = a_const->GetValueAsDouble();
|
||||
if (element_value == 0.0) {
|
||||
use_x = true;
|
||||
|
||||
@@ -283,9 +283,14 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
// use 0 for %min_value).
|
||||
auto clamp_index = [&inst, type_mgr, this, &replace_index](
|
||||
uint32_t operand_index, Instruction* old_value,
|
||||
Instruction* min_value, Instruction* max_value) {
|
||||
Instruction* min_value,
|
||||
Instruction* max_value) -> spv_result_t {
|
||||
auto* clamp_inst =
|
||||
MakeSClampInst(*type_mgr, old_value, min_value, max_value, &inst);
|
||||
if (clamp_inst == nullptr) {
|
||||
Fail();
|
||||
return SPV_ERROR_INTERNAL;
|
||||
}
|
||||
return replace_index(operand_index, clamp_inst);
|
||||
};
|
||||
|
||||
@@ -304,7 +309,11 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
|
||||
if (count <= 1) {
|
||||
// Replace the index with 0.
|
||||
return replace_index(operand_index, GetValueForType(0, index_type));
|
||||
Instruction* new_value = GetValueForType(0, index_type);
|
||||
if (new_value == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
return replace_index(operand_index, new_value);
|
||||
}
|
||||
|
||||
uint64_t maxval = count - 1;
|
||||
@@ -318,8 +327,15 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
// Determine the type for |maxval|.
|
||||
uint32_t next_id = context()->module()->IdBound();
|
||||
analysis::Integer signed_type_for_query(maxval_width, true);
|
||||
auto* maxval_type =
|
||||
type_mgr->GetRegisteredType(&signed_type_for_query)->AsInteger();
|
||||
auto* maxval_type_registered =
|
||||
type_mgr->GetRegisteredType(&signed_type_for_query);
|
||||
if (maxval_type_registered == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
auto* maxval_type = maxval_type_registered->AsInteger();
|
||||
if (maxval_type == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
if (next_id != context()->module()->IdBound()) {
|
||||
module_status_.modified = true;
|
||||
}
|
||||
@@ -352,15 +368,22 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
value = int_index_constant->GetS64BitValue();
|
||||
}
|
||||
if (value < 0) {
|
||||
return replace_index(operand_index, GetValueForType(0, index_type));
|
||||
Instruction* new_value = GetValueForType(0, index_type);
|
||||
if (new_value == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
return replace_index(operand_index, new_value);
|
||||
} else if (uint64_t(value) <= maxval) {
|
||||
// Nothing to do.
|
||||
return SPV_SUCCESS;
|
||||
} else {
|
||||
// Replace with maxval.
|
||||
assert(count > 0); // Already took care of this case above.
|
||||
return replace_index(operand_index,
|
||||
GetValueForType(maxval, maxval_type));
|
||||
Instruction* new_value = GetValueForType(maxval, maxval_type);
|
||||
if (new_value == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
return replace_index(operand_index, new_value);
|
||||
}
|
||||
} else {
|
||||
// Generate a clamp instruction.
|
||||
@@ -389,6 +412,9 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
}
|
||||
index_inst = WidenInteger(index_type->IsSigned(), maxval_width,
|
||||
index_inst, &inst);
|
||||
if (index_inst == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, clamp the index.
|
||||
@@ -438,28 +464,51 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
if (index_type->width() < target_width) {
|
||||
// Access chain indices are treated as signed integers.
|
||||
index_inst = WidenInteger(true, target_width, index_inst, &inst);
|
||||
if (index_inst == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
} else if (count_type->width() < target_width) {
|
||||
// Assume type sizes are treated as unsigned.
|
||||
count_inst = WidenInteger(false, target_width, count_inst, &inst);
|
||||
if (count_inst == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
}
|
||||
// Compute count - 1.
|
||||
// It doesn't matter if 1 is signed or unsigned.
|
||||
auto* one = GetValueForType(1, wider_type);
|
||||
auto* count_minus_1 = InsertInst(
|
||||
&inst, spv::Op::OpISub, type_mgr->GetId(wider_type), TakeNextId(),
|
||||
{{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {one->result_id()}}});
|
||||
if (!one) {
|
||||
return Fail();
|
||||
}
|
||||
auto* count_minus_1 =
|
||||
InsertInst(&inst, spv::Op::OpISub, type_mgr->GetId(wider_type),
|
||||
context()->TakeNextId(),
|
||||
{{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {one->result_id()}}});
|
||||
if (count_minus_1 == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
auto* zero = GetValueForType(0, wider_type);
|
||||
if (!zero) {
|
||||
return Fail();
|
||||
}
|
||||
// Make sure we clamp to an upper bound that is at most the signed max
|
||||
// for the target type.
|
||||
const uint64_t max_signed_value =
|
||||
((uint64_t(1) << (target_width - 1)) - 1);
|
||||
Instruction* max_signed_inst =
|
||||
GetValueForType(max_signed_value, wider_type);
|
||||
if (!max_signed_inst) {
|
||||
return Fail();
|
||||
}
|
||||
// Use unsigned-min to ensure that the result is always non-negative.
|
||||
// That ensures we satisfy the invariant for SClamp, where the "min"
|
||||
// argument we give it (zero), is no larger than the third argument.
|
||||
auto* upper_bound =
|
||||
MakeUMinInst(*type_mgr, count_minus_1,
|
||||
GetValueForType(max_signed_value, wider_type), &inst);
|
||||
MakeUMinInst(*type_mgr, count_minus_1, max_signed_inst, &inst);
|
||||
if (upper_bound == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
// Now clamp the index to this upper bound.
|
||||
return clamp_index(operand_index, index_inst, zero, upper_bound);
|
||||
}
|
||||
@@ -485,7 +534,7 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
case spv::Op::OpTypeVector: // Use component count
|
||||
{
|
||||
const uint32_t count = pointee_type->GetSingleWordOperand(2);
|
||||
clamp_to_literal_count(idx, count);
|
||||
if (clamp_to_literal_count(idx, count) != SPV_SUCCESS) return;
|
||||
pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
|
||||
} break;
|
||||
|
||||
@@ -493,7 +542,7 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
// The array length can be a spec constant, so go through the general
|
||||
// case.
|
||||
Instruction* array_len = GetDef(pointee_type->GetSingleWordOperand(2));
|
||||
clamp_to_count(idx, array_len);
|
||||
if (clamp_to_count(idx, array_len) != SPV_SUCCESS) return;
|
||||
pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
|
||||
} break;
|
||||
|
||||
@@ -537,7 +586,7 @@ void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
|
||||
if (!array_len) { // We've already signaled an error.
|
||||
return;
|
||||
}
|
||||
clamp_to_count(idx, array_len);
|
||||
if (clamp_to_count(idx, array_len) != SPV_SUCCESS) return;
|
||||
if (module_status_.failed) return;
|
||||
pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
|
||||
} break;
|
||||
@@ -563,7 +612,10 @@ uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
|
||||
}
|
||||
if (module_status_.glsl_insts_id == 0) {
|
||||
// Make a new import instruction.
|
||||
module_status_.glsl_insts_id = TakeNextId();
|
||||
module_status_.glsl_insts_id = context()->TakeNextId();
|
||||
if (module_status_.glsl_insts_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
std::vector<uint32_t> words = spvtools::utils::MakeVector(glsl);
|
||||
auto import_inst = MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpExtInstImport, 0, module_status_.glsl_insts_id,
|
||||
@@ -602,7 +654,10 @@ opt::Instruction* opt::GraphicsRobustAccessPass::WidenInteger(
|
||||
auto* type_mgr = context()->get_type_mgr();
|
||||
auto* unsigned_type = type_mgr->GetRegisteredType(&unsigned_type_for_query);
|
||||
auto type_id = context()->get_type_mgr()->GetId(unsigned_type);
|
||||
auto conversion_id = TakeNextId();
|
||||
auto conversion_id = context()->TakeNextId();
|
||||
if (conversion_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* conversion = InsertInst(
|
||||
before_inst, (sign_extend ? spv::Op::OpSConvert : spv::Op::OpUConvert),
|
||||
type_id, conversion_id, {{SPV_OPERAND_TYPE_ID, {value->result_id()}}});
|
||||
@@ -616,7 +671,13 @@ Instruction* GraphicsRobustAccessPass::MakeUMinInst(
|
||||
// the function so we force a deterministic ordering in case both of them need
|
||||
// to take a new ID.
|
||||
const uint32_t glsl_insts_id = GetGlslInsts();
|
||||
uint32_t smin_id = TakeNextId();
|
||||
if (glsl_insts_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
uint32_t smin_id = context()->TakeNextId();
|
||||
if (smin_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
|
||||
const auto ywidth = tm.GetType(y->type_id())->AsInteger()->width();
|
||||
assert(xwidth == ywidth);
|
||||
@@ -640,7 +701,13 @@ Instruction* GraphicsRobustAccessPass::MakeSClampInst(
|
||||
// the function so we force a deterministic ordering in case both of them need
|
||||
// to take a new ID.
|
||||
const uint32_t glsl_insts_id = GetGlslInsts();
|
||||
uint32_t clamp_id = TakeNextId();
|
||||
if (glsl_insts_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
uint32_t clamp_id = context()->TakeNextId();
|
||||
if (clamp_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
|
||||
const auto minwidth = tm.GetType(min->type_id())->AsInteger()->width();
|
||||
const auto maxwidth = tm.GetType(max->type_id())->AsInteger()->width();
|
||||
@@ -755,7 +822,11 @@ Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
|
||||
base_ptr_type->storage_class());
|
||||
|
||||
// Create the instruction and insert it.
|
||||
const auto new_access_chain_id = TakeNextId();
|
||||
const auto new_access_chain_id = context()->TakeNextId();
|
||||
if (new_access_chain_id == 0) {
|
||||
Fail();
|
||||
return nullptr;
|
||||
}
|
||||
auto* new_access_chain =
|
||||
InsertInst(current_access_chain, current_access_chain->opcode(),
|
||||
new_access_chain_type_id, new_access_chain_id, ops);
|
||||
@@ -784,7 +855,11 @@ Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
|
||||
uint32_t(struct_type->element_types().size() - 1);
|
||||
// Create the length-of-array instruction before the original access chain,
|
||||
// but after the generation of the pointer to the struct.
|
||||
const auto array_len_id = TakeNextId();
|
||||
const auto array_len_id = context()->TakeNextId();
|
||||
if (array_len_id == 0) {
|
||||
Fail();
|
||||
return nullptr;
|
||||
}
|
||||
analysis::Integer uint_type_for_query(32, false);
|
||||
auto* uint_type = type_mgr->GetRegisteredType(&uint_type_for_query);
|
||||
auto* array_len = InsertInst(
|
||||
@@ -935,12 +1010,18 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
|
||||
return type_mgr->GetRegisteredType(&proposed);
|
||||
}();
|
||||
|
||||
const uint32_t image_id = TakeNextId();
|
||||
const uint32_t image_id = context()->TakeNextId();
|
||||
if (image_id == 0) {
|
||||
return Fail();
|
||||
}
|
||||
auto* image =
|
||||
InsertInst(image_texel_pointer, spv::Op::OpLoad, image_type_id, image_id,
|
||||
{{SPV_OPERAND_TYPE_ID, {image_ptr->result_id()}}});
|
||||
|
||||
const uint32_t query_size_id = TakeNextId();
|
||||
const uint32_t query_size_id = context()->TakeNextId();
|
||||
if (query_size_id == 0) {
|
||||
return Fail();
|
||||
}
|
||||
auto* query_size =
|
||||
InsertInst(image_texel_pointer, spv::Op::OpImageQuerySize,
|
||||
type_mgr->GetTypeInstruction(query_size_type), query_size_id,
|
||||
@@ -968,7 +1049,10 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
|
||||
query_size_type, {component_1_id, component_1_id, component_6_id});
|
||||
auto* multiplicand_inst =
|
||||
constant_mgr->GetDefiningInstruction(multiplicand);
|
||||
const auto query_size_including_faces_id = TakeNextId();
|
||||
const auto query_size_including_faces_id = context()->TakeNextId();
|
||||
if (query_size_including_faces_id == 0) {
|
||||
return Fail();
|
||||
}
|
||||
query_size_including_faces = InsertInst(
|
||||
image_texel_pointer, spv::Op::OpIMul,
|
||||
type_mgr->GetTypeInstruction(query_size_type),
|
||||
@@ -992,7 +1076,10 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
|
||||
query_size_type,
|
||||
std::vector<uint32_t>(query_num_components, component_0_id));
|
||||
|
||||
const uint32_t query_max_including_faces_id = TakeNextId();
|
||||
const uint32_t query_max_including_faces_id = context()->TakeNextId();
|
||||
if (query_max_including_faces_id == 0) {
|
||||
return Fail();
|
||||
}
|
||||
auto* query_max_including_faces = InsertInst(
|
||||
image_texel_pointer, spv::Op::OpISub,
|
||||
type_mgr->GetTypeInstruction(query_size_type),
|
||||
@@ -1005,18 +1092,27 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
|
||||
auto* clamp_coord = MakeSClampInst(
|
||||
*type_mgr, coord, constant_mgr->GetDefiningInstruction(coordinate_0),
|
||||
query_max_including_faces, image_texel_pointer);
|
||||
if (clamp_coord == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
image_texel_pointer->SetInOperand(1, {clamp_coord->result_id()});
|
||||
|
||||
// Clamp the sample index
|
||||
if (multisampled) {
|
||||
// Get the sample count via OpImageQuerySamples
|
||||
const auto query_samples_id = TakeNextId();
|
||||
const auto query_samples_id = context()->TakeNextId();
|
||||
if (query_samples_id == 0) {
|
||||
return Fail();
|
||||
}
|
||||
auto* query_samples = InsertInst(
|
||||
image_texel_pointer, spv::Op::OpImageQuerySamples,
|
||||
constant_mgr->GetDefiningInstruction(component_0)->type_id(),
|
||||
query_samples_id, {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});
|
||||
|
||||
const auto max_samples_id = TakeNextId();
|
||||
const auto max_samples_id = context()->TakeNextId();
|
||||
if (max_samples_id == 0) {
|
||||
return Fail();
|
||||
}
|
||||
auto* max_samples = InsertInst(image_texel_pointer, spv::Op::OpImageQuerySamples,
|
||||
query_samples->type_id(), max_samples_id,
|
||||
{{SPV_OPERAND_TYPE_ID, {query_samples_id}},
|
||||
@@ -1025,6 +1121,9 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
|
||||
auto* clamp_samples = MakeSClampInst(
|
||||
*type_mgr, samples, constant_mgr->GetDefiningInstruction(coordinate_0),
|
||||
max_samples, image_texel_pointer);
|
||||
if (clamp_samples == nullptr) {
|
||||
return Fail();
|
||||
}
|
||||
image_texel_pointer->SetInOperand(2, {clamp_samples->result_id()});
|
||||
|
||||
} else {
|
||||
@@ -1041,6 +1140,9 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
|
||||
opt::Instruction* GraphicsRobustAccessPass::InsertInst(
|
||||
opt::Instruction* where_inst, spv::Op opcode, uint32_t type_id,
|
||||
uint32_t result_id, const Instruction::OperandList& operands) {
|
||||
if (result_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
module_status_.modified = true;
|
||||
auto* result = where_inst->InsertBefore(
|
||||
MakeUnique<Instruction>(context(), opcode, type_id, result_id, operands));
|
||||
|
||||
14
3rdparty/spirv-tools/source/opt/instruction.cpp
vendored
14
3rdparty/spirv-tools/source/opt/instruction.cpp
vendored
@@ -546,11 +546,13 @@ void Instruction::ClearDbgLineInsts() {
|
||||
clear_dbg_line_insts();
|
||||
}
|
||||
|
||||
void Instruction::UpdateDebugInfoFrom(const Instruction* from) {
|
||||
void Instruction::UpdateDebugInfoFrom(const Instruction* from,
|
||||
const Instruction* line) {
|
||||
if (from == nullptr) return;
|
||||
ClearDbgLineInsts();
|
||||
if (!from->dbg_line_insts().empty())
|
||||
AddDebugLine(&from->dbg_line_insts().back());
|
||||
const Instruction* fromLine = line != nullptr ? line : from;
|
||||
if (!fromLine->dbg_line_insts().empty())
|
||||
AddDebugLine(&fromLine->dbg_line_insts().back());
|
||||
SetDebugScope(from->GetDebugScope());
|
||||
if (!IsLineInst() &&
|
||||
context()->AreAnalysesValid(IRContext::kAnalysisDebugInfo)) {
|
||||
@@ -1033,6 +1035,12 @@ bool Instruction::IsOpcodeSafeToDelete() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (IsNonSemanticInstruction() &&
|
||||
(GetShader100DebugOpcode() == NonSemanticShaderDebugInfo100DebugDeclare ||
|
||||
GetShader100DebugOpcode() == NonSemanticShaderDebugInfo100DebugValue)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (opcode()) {
|
||||
case spv::Op::OpDPdx:
|
||||
case spv::Op::OpDPdy:
|
||||
|
||||
@@ -338,7 +338,8 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
|
||||
// Updates lexical scope of DebugScope and OpLine.
|
||||
void UpdateLexicalScope(uint32_t scope);
|
||||
// Updates OpLine and DebugScope based on the information of |from|.
|
||||
void UpdateDebugInfoFrom(const Instruction* from);
|
||||
void UpdateDebugInfoFrom(const Instruction* from,
|
||||
const Instruction* line = nullptr);
|
||||
// Remove the |index|-th operand
|
||||
void RemoveOperand(uint32_t index) {
|
||||
operands_.erase(operands_.begin() + index);
|
||||
|
||||
@@ -239,28 +239,34 @@ void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
|
||||
});
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
|
||||
Pass::Status
|
||||
InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
|
||||
Instruction* interface_var, Instruction* interface_var_type,
|
||||
uint32_t location, uint32_t component, uint32_t extra_array_length) {
|
||||
NestedCompositeComponents scalar_interface_vars =
|
||||
std::optional<NestedCompositeComponents> scalar_interface_vars =
|
||||
CreateScalarInterfaceVarsForReplacement(interface_var_type,
|
||||
GetStorageClass(interface_var),
|
||||
extra_array_length);
|
||||
|
||||
AddLocationAndComponentDecorations(scalar_interface_vars, &location,
|
||||
if (!scalar_interface_vars) {
|
||||
return Status::Failure;
|
||||
}
|
||||
|
||||
AddLocationAndComponentDecorations(*scalar_interface_vars, &location,
|
||||
component);
|
||||
KillLocationAndComponentDecorations(interface_var->result_id());
|
||||
|
||||
if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
|
||||
scalar_interface_vars)) {
|
||||
return false;
|
||||
Status status = ReplaceInterfaceVarWith(interface_var, extra_array_length,
|
||||
*scalar_interface_vars);
|
||||
if (status == Status::Failure) {
|
||||
return status;
|
||||
}
|
||||
|
||||
context()->KillInst(interface_var);
|
||||
return true;
|
||||
return status;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
|
||||
Pass::Status InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
|
||||
Instruction* interface_var, uint32_t extra_array_length,
|
||||
const NestedCompositeComponents& scalar_interface_vars) {
|
||||
std::vector<Instruction*> users;
|
||||
@@ -276,21 +282,24 @@ bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
|
||||
// interface variable.
|
||||
for (uint32_t index = 0; index < extra_array_length; ++index) {
|
||||
std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
|
||||
if (!ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, users, scalar_interface_vars,
|
||||
interface_var_component_indices, &index,
|
||||
&loads_to_component_values,
|
||||
&loads_for_access_chain_to_composites)) {
|
||||
return false;
|
||||
Status status = ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, users, scalar_interface_vars,
|
||||
interface_var_component_indices, &index, &loads_to_component_values,
|
||||
&loads_for_access_chain_to_composites);
|
||||
if (status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
AddComponentsToCompositesForLoads(loads_to_component_values,
|
||||
&loads_to_composites, 0);
|
||||
}
|
||||
} else if (!ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, users, scalar_interface_vars,
|
||||
interface_var_component_indices, nullptr, &loads_to_composites,
|
||||
&loads_for_access_chain_to_composites)) {
|
||||
return false;
|
||||
} else {
|
||||
Status status = ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, users, scalar_interface_vars,
|
||||
interface_var_component_indices, nullptr, &loads_to_composites,
|
||||
&loads_for_access_chain_to_composites);
|
||||
if (status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
}
|
||||
|
||||
ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
|
||||
@@ -298,7 +307,7 @@ bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
|
||||
loads_for_access_chain_to_composites);
|
||||
|
||||
KillInstructionsAndUsers(users);
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
|
||||
@@ -318,7 +327,8 @@ void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
|
||||
}
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
|
||||
Pass::Status
|
||||
InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const NestedCompositeComponents& scalar_interface_vars,
|
||||
@@ -329,15 +339,16 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
|
||||
loads_for_access_chain_to_composites) {
|
||||
if (!scalar_interface_vars.HasMultipleComponents()) {
|
||||
for (Instruction* interface_var_user : interface_var_users) {
|
||||
if (!ReplaceComponentOfInterfaceVarWith(
|
||||
interface_var, interface_var_user,
|
||||
scalar_interface_vars.GetComponentVariable(),
|
||||
interface_var_component_indices, extra_array_index,
|
||||
loads_to_composites, loads_for_access_chain_to_composites)) {
|
||||
return false;
|
||||
Status status = ReplaceComponentOfInterfaceVarWith(
|
||||
interface_var, interface_var_user,
|
||||
scalar_interface_vars.GetComponentVariable(),
|
||||
interface_var_component_indices, extra_array_index,
|
||||
loads_to_composites, loads_for_access_chain_to_composites);
|
||||
if (status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
return ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
|
||||
@@ -345,27 +356,28 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
|
||||
loads_for_access_chain_to_composites);
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::
|
||||
ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const std::vector<NestedCompositeComponents>& components,
|
||||
std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_composites) {
|
||||
Pass::Status
|
||||
InterfaceVariableScalarReplacement::ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const std::vector<NestedCompositeComponents>& components,
|
||||
std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_composites) {
|
||||
for (uint32_t i = 0; i < components.size(); ++i) {
|
||||
interface_var_component_indices.push_back(i);
|
||||
std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
|
||||
std::unordered_map<Instruction*, Instruction*>
|
||||
loads_for_access_chain_to_component_values;
|
||||
if (!ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, interface_var_users, components[i],
|
||||
interface_var_component_indices, extra_array_index,
|
||||
&loads_to_component_values,
|
||||
&loads_for_access_chain_to_component_values)) {
|
||||
return false;
|
||||
Status status = ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, interface_var_users, components[i],
|
||||
interface_var_component_indices, extra_array_index,
|
||||
&loads_to_component_values,
|
||||
&loads_for_access_chain_to_component_values);
|
||||
if (status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
interface_var_component_indices.pop_back();
|
||||
|
||||
@@ -378,10 +390,11 @@ bool InterfaceVariableScalarReplacement::
|
||||
AddComponentsToCompositesForLoads(loads_to_component_values,
|
||||
loads_to_composites, depth_to_component);
|
||||
}
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
|
||||
Pass::Status
|
||||
InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
|
||||
Instruction* interface_var, Instruction* interface_var_user,
|
||||
Instruction* scalar_var,
|
||||
const std::vector<uint32_t>& interface_var_component_indices,
|
||||
@@ -395,42 +408,49 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
|
||||
StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
|
||||
scalar_var, extra_array_index,
|
||||
interface_var_user);
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
if (opcode == spv::Op::OpLoad) {
|
||||
Instruction* scalar_load =
|
||||
LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
|
||||
if (scalar_load == nullptr) {
|
||||
return Status::Failure;
|
||||
}
|
||||
loads_to_component_values->insert({interface_var_user, scalar_load});
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
// Copy OpName and annotation instructions only once. Therefore, we create
|
||||
// them only for the first element of the extra array.
|
||||
if (extra_array_index && *extra_array_index != 0) return true;
|
||||
if (extra_array_index && *extra_array_index != 0)
|
||||
return Status::SuccessWithChange;
|
||||
|
||||
if (opcode == spv::Op::OpDecorateId || opcode == spv::Op::OpDecorateString ||
|
||||
opcode == spv::Op::OpDecorate) {
|
||||
CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpName) {
|
||||
std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
|
||||
new_inst->SetInOperand(0, {scalar_var->result_id()});
|
||||
context()->AddDebug2Inst(std::move(new_inst));
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpEntryPoint) {
|
||||
return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
|
||||
scalar_var->result_id());
|
||||
if (ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
|
||||
scalar_var->result_id())) {
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
return Status::Failure;
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpAccessChain) {
|
||||
ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
|
||||
scalar_var,
|
||||
loads_for_access_chain_to_component_values);
|
||||
return true;
|
||||
return Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
std::string message("Unhandled instruction");
|
||||
@@ -440,7 +460,7 @@ bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
|
||||
"\nfor interface variable scalar replacement\n " +
|
||||
interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
|
||||
return false;
|
||||
return Status::Failure;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
|
||||
@@ -470,10 +490,14 @@ Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
|
||||
uint32_t ptr_type_id =
|
||||
GetPointerType(*component_type_id, GetStorageClass(var));
|
||||
|
||||
std::unique_ptr<Instruction> new_access_chain(new Instruction(
|
||||
context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
|
||||
uint32_t new_id = TakeNextId();
|
||||
if (new_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<Instruction> new_access_chain(
|
||||
new Instruction(context(), spv::Op::OpAccessChain, ptr_type_id, new_id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
|
||||
for (uint32_t index_id : index_ids) {
|
||||
new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
|
||||
}
|
||||
@@ -490,12 +514,16 @@ Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
|
||||
uint32_t ptr_type_id =
|
||||
GetPointerType(component_type_id, GetStorageClass(var));
|
||||
uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index);
|
||||
std::unique_ptr<Instruction> new_access_chain(new Instruction(
|
||||
context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {var->result_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {index_id}},
|
||||
}));
|
||||
uint32_t new_id = TakeNextId();
|
||||
if (new_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<Instruction> new_access_chain(
|
||||
new Instruction(context(), spv::Op::OpAccessChain, ptr_type_id, new_id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {var->result_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {index_id}},
|
||||
}));
|
||||
Instruction* inst = new_access_chain.get();
|
||||
context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
|
||||
insert_before->InsertBefore(std::move(new_access_chain));
|
||||
@@ -617,6 +645,9 @@ void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
|
||||
component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
|
||||
ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
|
||||
*extra_array_index, insert_before);
|
||||
if (ptr == nullptr) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
|
||||
@@ -635,6 +666,9 @@ Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
|
||||
component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
|
||||
ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
|
||||
*extra_array_index, insert_before);
|
||||
if (ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return CreateLoad(component_type_id, ptr, insert_before);
|
||||
@@ -642,8 +676,12 @@ Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::CreateLoad(
|
||||
uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
|
||||
uint32_t new_id = TakeNextId();
|
||||
if (new_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<Instruction> load(
|
||||
new Instruction(context(), spv::Op::OpLoad, type_id, TakeNextId(),
|
||||
new Instruction(context(), spv::Op::OpLoad, type_id, new_id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
|
||||
Instruction* load_inst = load.get();
|
||||
@@ -658,6 +696,9 @@ void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
|
||||
const uint32_t* extra_array_index, Instruction* insert_before) {
|
||||
std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
|
||||
component_type_id, value_id, component_indices, extra_array_index));
|
||||
if (composite_extract == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<Instruction> new_store(
|
||||
new Instruction(context(), spv::Op::OpStore));
|
||||
@@ -677,6 +718,9 @@ Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
|
||||
uint32_t type_id, uint32_t composite_id,
|
||||
const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
|
||||
uint32_t component_id = TakeNextId();
|
||||
if (component_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
Instruction* composite_extract = new Instruction(
|
||||
context(), spv::Op::OpCompositeExtract, type_id, component_id,
|
||||
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
|
||||
@@ -716,6 +760,9 @@ Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
|
||||
if (!indexes.empty()) {
|
||||
ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
|
||||
&component_type_id);
|
||||
if (ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return CreateLoad(component_type_id, ptr, insert_before);
|
||||
@@ -730,7 +777,10 @@ InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
|
||||
type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
|
||||
depth_to_component);
|
||||
}
|
||||
uint32_t new_id = context()->TakeNextId();
|
||||
uint32_t new_id = TakeNextId();
|
||||
if (new_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<Instruction> new_composite_construct(new Instruction(
|
||||
context(), spv::Op::OpCompositeConstruct, type_id, new_id, {}));
|
||||
Instruction* composite_construct = new_composite_construct.get();
|
||||
@@ -767,6 +817,10 @@ void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
|
||||
if (itr == loads_to_composites->end()) {
|
||||
composite_construct =
|
||||
CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
|
||||
if (composite_construct == nullptr) {
|
||||
assert(false && "Could not create composite construct");
|
||||
return;
|
||||
}
|
||||
loads_to_composites->insert({load, composite_construct});
|
||||
} else {
|
||||
composite_construct = itr->second;
|
||||
@@ -795,7 +849,7 @@ uint32_t InterfaceVariableScalarReplacement::GetPointerType(
|
||||
return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
|
||||
}
|
||||
|
||||
InterfaceVariableScalarReplacement::NestedCompositeComponents
|
||||
std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
|
||||
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
|
||||
Instruction* interface_var_type, spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length) {
|
||||
@@ -807,16 +861,19 @@ InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
|
||||
|
||||
NestedCompositeComponents scalar_vars;
|
||||
while (array_length > 0) {
|
||||
NestedCompositeComponents scalar_vars_for_element =
|
||||
std::optional<NestedCompositeComponents> scalar_vars_for_element =
|
||||
CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
|
||||
extra_array_length);
|
||||
scalar_vars.AddComponent(scalar_vars_for_element);
|
||||
if (!scalar_vars_for_element) {
|
||||
return std::nullopt;
|
||||
}
|
||||
scalar_vars.AddComponent(*scalar_vars_for_element);
|
||||
--array_length;
|
||||
}
|
||||
return scalar_vars;
|
||||
}
|
||||
|
||||
InterfaceVariableScalarReplacement::NestedCompositeComponents
|
||||
std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
|
||||
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
|
||||
Instruction* interface_var_type, spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length) {
|
||||
@@ -830,16 +887,19 @@ InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
|
||||
|
||||
NestedCompositeComponents scalar_vars;
|
||||
while (column_count > 0) {
|
||||
NestedCompositeComponents scalar_vars_for_column =
|
||||
std::optional<NestedCompositeComponents> scalar_vars_for_column =
|
||||
CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
|
||||
extra_array_length);
|
||||
scalar_vars.AddComponent(scalar_vars_for_column);
|
||||
if (!scalar_vars_for_column) {
|
||||
return std::nullopt;
|
||||
}
|
||||
scalar_vars.AddComponent(*scalar_vars_for_column);
|
||||
--column_count;
|
||||
}
|
||||
return scalar_vars;
|
||||
}
|
||||
|
||||
InterfaceVariableScalarReplacement::NestedCompositeComponents
|
||||
std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
|
||||
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
|
||||
Instruction* interface_var_type, spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length) {
|
||||
@@ -864,6 +924,9 @@ InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
|
||||
uint32_t ptr_type_id =
|
||||
context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
|
||||
uint32_t id = TakeNextId();
|
||||
if (id == 0) {
|
||||
return std::nullopt;
|
||||
}
|
||||
std::unique_ptr<Instruction> variable(
|
||||
new Instruction(context(), spv::Op::OpVariable, ptr_type_id, id,
|
||||
std::initializer_list<Operand>{
|
||||
@@ -953,9 +1016,9 @@ InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
|
||||
location, component,
|
||||
extra_array_length)) {
|
||||
if (ReplaceInterfaceVariableWithScalars(
|
||||
interface_var, interface_var_type, location, component,
|
||||
extra_array_length) == Pass::Status::Failure) {
|
||||
return Pass::Status::Failure;
|
||||
}
|
||||
status = Pass::Status::SuccessWithChange;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_
|
||||
#define SOURCE_OPT_INTERFACE_VAR_SROA_H_
|
||||
|
||||
#include <optional>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "source/opt/pass.h"
|
||||
@@ -100,25 +101,26 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// If |extra_array_length| is 0, it means |interface_var| has a Patch
|
||||
// decoration. Otherwise, |extra_array_length| denotes the length of the extra
|
||||
// array of |interface_var|.
|
||||
bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
|
||||
Instruction* interface_var_type,
|
||||
uint32_t location,
|
||||
uint32_t component,
|
||||
uint32_t extra_array_length);
|
||||
Status ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
|
||||
Instruction* interface_var_type,
|
||||
uint32_t location,
|
||||
uint32_t component,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
// Creates scalar variables with the storage classe |storage_class| to replace
|
||||
// an interface variable whose type is |interface_var_type|. If
|
||||
// |extra_array_length| is not zero, adds the extra arrayness to the created
|
||||
// scalar variables.
|
||||
NestedCompositeComponents CreateScalarInterfaceVarsForReplacement(
|
||||
Instruction* interface_var_type, spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
std::optional<NestedCompositeComponents>
|
||||
CreateScalarInterfaceVarsForReplacement(Instruction* interface_var_type,
|
||||
spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
// Creates scalar variables with the storage classe |storage_class| to replace
|
||||
// the interface variable whose type is OpTypeArray |interface_var_type| with.
|
||||
// If |extra_array_length| is not zero, adds the extra arrayness to all the
|
||||
// scalar variables.
|
||||
NestedCompositeComponents CreateScalarInterfaceVarsForArray(
|
||||
std::optional<NestedCompositeComponents> CreateScalarInterfaceVarsForArray(
|
||||
Instruction* interface_var_type, spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
@@ -126,7 +128,7 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// the interface variable whose type is OpTypeMatrix |interface_var_type|
|
||||
// with. If |extra_array_length| is not zero, adds the extra arrayness to all
|
||||
// the scalar variables.
|
||||
NestedCompositeComponents CreateScalarInterfaceVarsForMatrix(
|
||||
std::optional<NestedCompositeComponents> CreateScalarInterfaceVarsForMatrix(
|
||||
Instruction* interface_var_type, spv::StorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
@@ -142,7 +144,7 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// |extra_arrayness| is the extra arrayness of the interface variable.
|
||||
// |scalar_interface_vars| contains the nested variables to replace the
|
||||
// interface variable with.
|
||||
bool ReplaceInterfaceVarWith(
|
||||
Status ReplaceInterfaceVarWith(
|
||||
Instruction* interface_var, uint32_t extra_arrayness,
|
||||
const NestedCompositeComponents& scalar_interface_vars);
|
||||
|
||||
@@ -155,7 +157,7 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// construct instructions to be replaced with load instructions of access
|
||||
// chain instructions in |interface_var_users| via
|
||||
// |loads_for_access_chain_to_composites|.
|
||||
bool ReplaceComponentsOfInterfaceVarWith(
|
||||
Status ReplaceComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const NestedCompositeComponents& scalar_interface_vars,
|
||||
@@ -174,7 +176,7 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// via |loads_to_composites|. Returns composite construct instructions to be
|
||||
// replaced with load instructions of access chain instructions in
|
||||
// |interface_var_users| via |loads_for_access_chain_to_composites|.
|
||||
bool ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
Status ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const std::vector<NestedCompositeComponents>& components,
|
||||
@@ -192,7 +194,7 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// |loads_to_component_values|. If |interface_var_user| is an access chain,
|
||||
// returns the component value for loads of |interface_var_user| via
|
||||
// |loads_for_access_chain_to_component_values|.
|
||||
bool ReplaceComponentOfInterfaceVarWith(
|
||||
Status ReplaceComponentOfInterfaceVarWith(
|
||||
Instruction* interface_var, Instruction* interface_var_user,
|
||||
Instruction* scalar_var,
|
||||
const std::vector<uint32_t>& interface_var_component_indices,
|
||||
@@ -389,6 +391,9 @@ class InterfaceVariableScalarReplacement : public Pass {
|
||||
// A set of interface variables without the extra arrayness for any of the
|
||||
// entry points.
|
||||
std::unordered_set<Instruction*> vars_without_extra_arrayness;
|
||||
|
||||
// Returns the next available id, or 0 if the id overflows.
|
||||
uint32_t TakeNextId() { return context()->TakeNextId(); }
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
||||
@@ -294,8 +294,12 @@ bool InvocationInterlockPlacementPass::removeUnneededInstructions(
|
||||
BasicBlock* InvocationInterlockPlacementPass::splitEdge(BasicBlock* block,
|
||||
uint32_t succ_id) {
|
||||
// Create a new block to replace the critical edge.
|
||||
uint32_t new_id = context()->TakeNextId();
|
||||
if (new_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_succ_temp = MakeUnique<BasicBlock>(
|
||||
MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, TakeNextId(),
|
||||
MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, new_id,
|
||||
std::initializer_list<Operand>{}));
|
||||
auto* new_succ = new_succ_temp.get();
|
||||
|
||||
@@ -325,7 +329,7 @@ BasicBlock* InvocationInterlockPlacementPass::splitEdge(BasicBlock* block,
|
||||
return new_succ;
|
||||
}
|
||||
|
||||
bool InvocationInterlockPlacementPass::placeInstructionsForEdge(
|
||||
Pass::Status InvocationInterlockPlacementPass::placeInstructionsForEdge(
|
||||
BasicBlock* block, uint32_t next_id, BlockSet& inside,
|
||||
BlockSet& previous_inside, spv::Op opcode, bool reverse_cfg) {
|
||||
bool modified = false;
|
||||
@@ -372,31 +376,45 @@ bool InvocationInterlockPlacementPass::placeInstructionsForEdge(
|
||||
new_branch = splitEdge(cfg()->block(next_id), block->id());
|
||||
}
|
||||
|
||||
if (!new_branch) {
|
||||
return Status::Failure;
|
||||
}
|
||||
|
||||
auto inst = new Instruction(context(), opcode);
|
||||
inst->InsertBefore(&*new_branch->tail());
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
bool InvocationInterlockPlacementPass::placeInstructions(BasicBlock* block) {
|
||||
bool modified = false;
|
||||
Pass::Status InvocationInterlockPlacementPass::placeInstructions(
|
||||
BasicBlock* block) {
|
||||
Status status = Status::SuccessWithoutChange;
|
||||
|
||||
block->ForEachSuccessorLabel([this, block, &modified](uint32_t succ_id) {
|
||||
modified |= placeInstructionsForEdge(
|
||||
block->ForEachSuccessorLabel([this, block, &status](uint32_t succ_id) {
|
||||
if (status == Status::Failure) {
|
||||
return;
|
||||
}
|
||||
Status edge_status = placeInstructionsForEdge(
|
||||
block, succ_id, after_begin_, predecessors_after_begin_,
|
||||
spv::Op::OpBeginInvocationInterlockEXT, /* reverse_cfg= */ true);
|
||||
modified |= placeInstructionsForEdge(cfg()->block(succ_id), block->id(),
|
||||
before_end_, successors_before_end_,
|
||||
spv::Op::OpEndInvocationInterlockEXT,
|
||||
/* reverse_cfg= */ false);
|
||||
status = CombineStatus(status, edge_status);
|
||||
if (status == Status::Failure) {
|
||||
return;
|
||||
}
|
||||
|
||||
edge_status = placeInstructionsForEdge(cfg()->block(succ_id), block->id(),
|
||||
before_end_, successors_before_end_,
|
||||
spv::Op::OpEndInvocationInterlockEXT,
|
||||
/* reverse_cfg= */ false);
|
||||
status = CombineStatus(status, edge_status);
|
||||
});
|
||||
|
||||
return modified;
|
||||
return status;
|
||||
}
|
||||
|
||||
bool InvocationInterlockPlacementPass::processFragmentShaderEntry(
|
||||
Pass::Status InvocationInterlockPlacementPass::processFragmentShaderEntry(
|
||||
Function* entry_func) {
|
||||
bool modified = false;
|
||||
|
||||
@@ -417,9 +435,15 @@ bool InvocationInterlockPlacementPass::processFragmentShaderEntry(
|
||||
|
||||
for (BasicBlock* block : original_blocks) {
|
||||
modified |= removeUnneededInstructions(block);
|
||||
modified |= placeInstructions(block);
|
||||
Status place_status = placeInstructions(block);
|
||||
if (place_status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
if (place_status == Status::SuccessWithChange) {
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
bool InvocationInterlockPlacementPass::isFragmentShaderInterlockEnabled() {
|
||||
@@ -452,7 +476,7 @@ Pass::Status InvocationInterlockPlacementPass::Process() {
|
||||
return Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
bool modified = false;
|
||||
Status status = Status::SuccessWithoutChange;
|
||||
|
||||
std::unordered_set<Function*> entry_points;
|
||||
for (Instruction& entry_inst : context()->module()->entry_points()) {
|
||||
@@ -466,7 +490,9 @@ Pass::Status InvocationInterlockPlacementPass::Process() {
|
||||
Function* func = &*fi;
|
||||
recordBeginOrEndInFunction(func);
|
||||
if (!entry_points.count(func) && extracted_functions_.count(func)) {
|
||||
modified |= removeBeginAndEndInstructionsFromFunction(func);
|
||||
if (removeBeginAndEndInstructionsFromFunction(func)) {
|
||||
status = Status::SuccessWithChange;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -482,11 +508,14 @@ Pass::Status InvocationInterlockPlacementPass::Process() {
|
||||
continue;
|
||||
}
|
||||
|
||||
modified |= processFragmentShaderEntry(entry_func);
|
||||
Status frag_status = processFragmentShaderEntry(entry_func);
|
||||
if (frag_status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
status = CombineStatus(status, frag_status);
|
||||
}
|
||||
|
||||
return modified ? Pass::Status::SuccessWithChange
|
||||
: Pass::Status::SuccessWithoutChange;
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
|
||||
@@ -120,14 +120,14 @@ class InvocationInterlockPlacementPass : public Pass {
|
||||
// For the edge from block to next_id, places a begin or end instruction on
|
||||
// the edge, based on the direction we are walking the CFG, specified in
|
||||
// reverse_cfg.
|
||||
bool placeInstructionsForEdge(BasicBlock* block, uint32_t next_id,
|
||||
BlockSet& inside, BlockSet& previous_inside,
|
||||
spv::Op opcode, bool reverse_cfg);
|
||||
Status placeInstructionsForEdge(BasicBlock* block, uint32_t next_id,
|
||||
BlockSet& inside, BlockSet& previous_inside,
|
||||
spv::Op opcode, bool reverse_cfg);
|
||||
// Calls placeInstructionsForEdge for each edge in block.
|
||||
bool placeInstructions(BasicBlock* block);
|
||||
Status placeInstructions(BasicBlock* block);
|
||||
|
||||
// Processes a single fragment shader entry function.
|
||||
bool processFragmentShaderEntry(Function* entry_func);
|
||||
Status processFragmentShaderEntry(Function* entry_func);
|
||||
|
||||
// Returns whether the module has the SPV_EXT_fragment_shader_interlock
|
||||
// extension and one of the FragmentShader*InterlockEXT capabilities.
|
||||
|
||||
@@ -201,7 +201,9 @@ Instruction* IRContext::KillInst(Instruction* inst) {
|
||||
constant_mgr_->RemoveId(inst->result_id());
|
||||
}
|
||||
if (inst->opcode() == spv::Op::OpCapability ||
|
||||
inst->opcode() == spv::Op::OpExtension) {
|
||||
inst->opcode() == spv::Op::OpConditionalCapabilityINTEL ||
|
||||
inst->opcode() == spv::Op::OpExtension ||
|
||||
inst->opcode() == spv::Op::OpConditionalExtensionINTEL) {
|
||||
// We reset the feature manager, instead of updating it, because it is just
|
||||
// as much work. We would have to remove all capabilities implied by this
|
||||
// capability that are not also implied by the remaining OpCapability
|
||||
@@ -382,6 +384,7 @@ bool IRContext::IsConsistent() {
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
if (AreAnalysesValid(kAnalysisIdToFuncMapping)) {
|
||||
for (auto& fn : *module_) {
|
||||
if (id_to_func_[fn.result_id()] != &fn) {
|
||||
@@ -398,8 +401,9 @@ bool IRContext::IsConsistent() {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}))
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,9 +181,11 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) {
|
||||
} else {
|
||||
if (function_ == nullptr) { // Outside function definition
|
||||
SPIRV_ASSERT(consumer_, block_ == nullptr);
|
||||
if (opcode == spv::Op::OpCapability) {
|
||||
if (opcode == spv::Op::OpCapability ||
|
||||
opcode == spv::Op::OpConditionalCapabilityINTEL) {
|
||||
module_->AddCapability(std::move(spv_inst));
|
||||
} else if (opcode == spv::Op::OpExtension) {
|
||||
} else if (opcode == spv::Op::OpExtension ||
|
||||
opcode == spv::Op::OpConditionalExtensionINTEL) {
|
||||
module_->AddExtension(std::move(spv_inst));
|
||||
} else if (opcode == spv::Op::OpExtInstImport) {
|
||||
module_->AddExtInstImport(std::move(spv_inst));
|
||||
|
||||
14
3rdparty/spirv-tools/source/opt/loop_fission.cpp
vendored
14
3rdparty/spirv-tools/source/opt/loop_fission.cpp
vendored
@@ -362,14 +362,19 @@ Loop* LoopFissionImpl::SplitLoop() {
|
||||
LoopUtils util{context_, loop_};
|
||||
LoopUtils::LoopCloningResult clone_results;
|
||||
Loop* cloned_loop = util.CloneAndAttachLoopToHeader(&clone_results);
|
||||
if (!cloned_loop) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Update the OpLoopMerge in the cloned loop.
|
||||
cloned_loop->UpdateLoopMergeInst();
|
||||
|
||||
// Add the loop_ to the module.
|
||||
// TODO(1841): Handle failure to create pre-header.
|
||||
Function::iterator it =
|
||||
util.GetFunction()->FindBlock(loop_->GetOrCreatePreHeaderBlock()->id());
|
||||
BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
|
||||
if (!pre_header) {
|
||||
return nullptr;
|
||||
}
|
||||
Function::iterator it = util.GetFunction()->FindBlock(pre_header->id());
|
||||
util.GetFunction()->AddBasicBlocks(clone_results.cloned_bb_.begin(),
|
||||
clone_results.cloned_bb_.end(), ++it);
|
||||
loop_->SetPreHeaderBlock(cloned_loop->GetMergeBlock());
|
||||
@@ -478,6 +483,9 @@ Pass::Status LoopFissionPass::Process() {
|
||||
|
||||
if (impl.CanPerformSplit()) {
|
||||
Loop* second_loop = impl.SplitLoop();
|
||||
if (!second_loop) {
|
||||
return Status::Failure;
|
||||
}
|
||||
changed = true;
|
||||
context()->InvalidateAnalysesExceptFor(
|
||||
IRContext::kAnalysisLoopAnalysis);
|
||||
|
||||
208
3rdparty/spirv-tools/source/opt/loop_peeling.cpp
vendored
208
3rdparty/spirv-tools/source/opt/loop_peeling.cpp
vendored
@@ -45,7 +45,7 @@ void GetBlocksInPath(uint32_t block, uint32_t entry,
|
||||
|
||||
size_t LoopPeelingPass::code_grow_threshold_ = 1000;
|
||||
|
||||
void LoopPeeling::DuplicateAndConnectLoop(
|
||||
bool LoopPeeling::DuplicateAndConnectLoop(
|
||||
LoopUtils::LoopCloningResult* clone_results) {
|
||||
CFG& cfg = *context_->cfg();
|
||||
analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
|
||||
@@ -53,12 +53,17 @@ void LoopPeeling::DuplicateAndConnectLoop(
|
||||
assert(CanPeelLoop() && "Cannot peel loop!");
|
||||
|
||||
std::vector<BasicBlock*> ordered_loop_blocks;
|
||||
// TODO(1841): Handle failure to create pre-header.
|
||||
BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
|
||||
if (!pre_header) {
|
||||
return false;
|
||||
}
|
||||
|
||||
loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks);
|
||||
|
||||
cloned_loop_ = loop_utils_.CloneLoop(clone_results, ordered_loop_blocks);
|
||||
if (!cloned_loop_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add the basic block to the function.
|
||||
Function::iterator it =
|
||||
@@ -146,17 +151,21 @@ void LoopPeeling::DuplicateAndConnectLoop(
|
||||
|
||||
// Force the creation of a new preheader for the original loop and set it as
|
||||
// the merge block for the cloned loop.
|
||||
// TODO(1841): Handle failure to create pre-header.
|
||||
cloned_loop_->SetMergeBlock(loop_->GetOrCreatePreHeaderBlock());
|
||||
BasicBlock* new_pre_header = loop_->GetOrCreatePreHeaderBlock();
|
||||
if (!new_pre_header) {
|
||||
return false;
|
||||
}
|
||||
cloned_loop_->SetMergeBlock(new_pre_header);
|
||||
return true;
|
||||
}
|
||||
|
||||
void LoopPeeling::InsertCanonicalInductionVariable(
|
||||
bool LoopPeeling::InsertCanonicalInductionVariable(
|
||||
LoopUtils::LoopCloningResult* clone_results) {
|
||||
if (original_loop_canonical_induction_variable_) {
|
||||
canonical_induction_variable_ =
|
||||
context_->get_def_use_mgr()->GetDef(clone_results->value_map_.at(
|
||||
original_loop_canonical_induction_variable_->result_id()));
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
BasicBlock::iterator insert_point = GetClonedLoop()->GetLatchBlock()->tail();
|
||||
@@ -168,19 +177,25 @@ void LoopPeeling::InsertCanonicalInductionVariable(
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
|
||||
Instruction* uint_1_cst =
|
||||
builder.GetIntConstant<uint32_t>(1, int_type_->IsSigned());
|
||||
if (!uint_1_cst) return false;
|
||||
// Create the increment.
|
||||
// Note that we do "1 + 1" here, one of the operand should the phi
|
||||
// value but we don't have it yet. The operand will be set latter.
|
||||
Instruction* iv_inc = builder.AddIAdd(
|
||||
uint_1_cst->type_id(), uint_1_cst->result_id(), uint_1_cst->result_id());
|
||||
if (!iv_inc) return false;
|
||||
|
||||
builder.SetInsertPoint(&*GetClonedLoop()->GetHeaderBlock()->begin());
|
||||
|
||||
Instruction* initial_value =
|
||||
builder.GetIntConstant<uint32_t>(0, int_type_->IsSigned());
|
||||
if (!initial_value) return false;
|
||||
|
||||
canonical_induction_variable_ = builder.AddPhi(
|
||||
uint_1_cst->type_id(),
|
||||
{builder.GetIntConstant<uint32_t>(0, int_type_->IsSigned())->result_id(),
|
||||
GetClonedLoop()->GetPreHeaderBlock()->id(), iv_inc->result_id(),
|
||||
GetClonedLoop()->GetLatchBlock()->id()});
|
||||
{initial_value->result_id(), GetClonedLoop()->GetPreHeaderBlock()->id(),
|
||||
iv_inc->result_id(), GetClonedLoop()->GetLatchBlock()->id()});
|
||||
if (!canonical_induction_variable_) return false;
|
||||
// Connect everything.
|
||||
iv_inc->SetInOperand(0, {canonical_induction_variable_->result_id()});
|
||||
|
||||
@@ -191,6 +206,7 @@ void LoopPeeling::InsertCanonicalInductionVariable(
|
||||
if (do_while_form_) {
|
||||
canonical_induction_variable_ = iv_inc;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void LoopPeeling::GetIteratorUpdateOperations(
|
||||
@@ -308,7 +324,7 @@ void LoopPeeling::GetIteratingExitValues() {
|
||||
}
|
||||
}
|
||||
|
||||
void LoopPeeling::FixExitCondition(
|
||||
bool LoopPeeling::FixExitCondition(
|
||||
const std::function<uint32_t(Instruction*)>& condition_builder) {
|
||||
CFG& cfg = *context_->cfg();
|
||||
|
||||
@@ -329,7 +345,11 @@ void LoopPeeling::FixExitCondition(
|
||||
--insert_point;
|
||||
}
|
||||
|
||||
exit_condition->SetInOperand(0, {condition_builder(&*insert_point)});
|
||||
uint32_t new_cond_id = condition_builder(&*insert_point);
|
||||
if (new_cond_id == 0) {
|
||||
return false;
|
||||
}
|
||||
exit_condition->SetInOperand(0, {new_cond_id});
|
||||
|
||||
uint32_t to_continue_block_idx =
|
||||
GetClonedLoop()->IsInsideLoop(exit_condition->GetSingleWordInOperand(1))
|
||||
@@ -341,6 +361,7 @@ void LoopPeeling::FixExitCondition(
|
||||
|
||||
// Update def/use manager.
|
||||
context_->get_def_use_mgr()->AnalyzeInstUse(exit_condition);
|
||||
return true;
|
||||
}
|
||||
|
||||
BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) {
|
||||
@@ -348,10 +369,13 @@ BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) {
|
||||
CFG& cfg = *context_->cfg();
|
||||
assert(cfg.preds(bb->id()).size() == 1 && "More than one predecessor");
|
||||
|
||||
// TODO(1841): Handle id overflow.
|
||||
uint32_t new_id = context_->TakeNextId();
|
||||
if (new_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<BasicBlock> new_bb =
|
||||
MakeUnique<BasicBlock>(std::unique_ptr<Instruction>(new Instruction(
|
||||
context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})));
|
||||
MakeUnique<BasicBlock>(std::unique_ptr<Instruction>(
|
||||
new Instruction(context_, spv::Op::OpLabel, 0, new_id, {})));
|
||||
// Update the loop descriptor.
|
||||
Loop* in_loop = (*loop_utils_.GetLoopDescriptor())[bb];
|
||||
if (in_loop) {
|
||||
@@ -394,8 +418,10 @@ BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) {
|
||||
|
||||
BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition,
|
||||
BasicBlock* if_merge) {
|
||||
// TODO(1841): Handle failure to create pre-header.
|
||||
BasicBlock* if_block = loop->GetOrCreatePreHeaderBlock();
|
||||
if (!if_block) {
|
||||
return nullptr;
|
||||
}
|
||||
// Will no longer be a pre-header because of the if.
|
||||
loop->SetPreHeaderBlock(nullptr);
|
||||
// Kill the branch to the header.
|
||||
@@ -411,48 +437,63 @@ BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition,
|
||||
return if_block;
|
||||
}
|
||||
|
||||
void LoopPeeling::PeelBefore(uint32_t peel_factor) {
|
||||
bool LoopPeeling::PeelBefore(uint32_t peel_factor) {
|
||||
assert(CanPeelLoop() && "Cannot peel loop");
|
||||
LoopUtils::LoopCloningResult clone_results;
|
||||
|
||||
// Clone the loop and insert the cloned one before the loop.
|
||||
DuplicateAndConnectLoop(&clone_results);
|
||||
if (!DuplicateAndConnectLoop(&clone_results)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add a canonical induction variable "canonical_induction_variable_".
|
||||
InsertCanonicalInductionVariable(&clone_results);
|
||||
if (!InsertCanonicalInductionVariable(&clone_results)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
InstructionBuilder builder(
|
||||
context_, &*cloned_loop_->GetPreHeaderBlock()->tail(),
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
|
||||
Instruction* factor =
|
||||
builder.GetIntConstant(peel_factor, int_type_->IsSigned());
|
||||
if (!factor) return false;
|
||||
|
||||
Instruction* has_remaining_iteration = builder.AddLessThan(
|
||||
factor->result_id(), loop_iteration_count_->result_id());
|
||||
if (!has_remaining_iteration) return false;
|
||||
Instruction* max_iteration = builder.AddSelect(
|
||||
factor->type_id(), has_remaining_iteration->result_id(),
|
||||
factor->result_id(), loop_iteration_count_->result_id());
|
||||
if (!max_iteration) return false;
|
||||
|
||||
// Change the exit condition of the cloned loop to be (exit when become
|
||||
// false):
|
||||
// "canonical_induction_variable_" < min("factor", "loop_iteration_count_")
|
||||
FixExitCondition([max_iteration, this](Instruction* insert_before_point) {
|
||||
return InstructionBuilder(context_, insert_before_point,
|
||||
IRContext::kAnalysisDefUse |
|
||||
IRContext::kAnalysisInstrToBlockMapping)
|
||||
.AddLessThan(canonical_induction_variable_->result_id(),
|
||||
max_iteration->result_id())
|
||||
->result_id();
|
||||
});
|
||||
if (!FixExitCondition(
|
||||
[max_iteration, this](Instruction* insert_before_point) {
|
||||
Instruction* new_cond =
|
||||
InstructionBuilder(context_, insert_before_point,
|
||||
IRContext::kAnalysisDefUse |
|
||||
IRContext::kAnalysisInstrToBlockMapping)
|
||||
.AddLessThan(canonical_induction_variable_->result_id(),
|
||||
max_iteration->result_id());
|
||||
return new_cond ? new_cond->result_id() : 0;
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// "Protect" the second loop: the second loop can only be executed if
|
||||
// |has_remaining_iteration| is true (i.e. factor < loop_iteration_count_).
|
||||
BasicBlock* if_merge_block = loop_->GetMergeBlock();
|
||||
loop_->SetMergeBlock(CreateBlockBefore(loop_->GetMergeBlock()));
|
||||
BasicBlock* new_merge_block = CreateBlockBefore(loop_->GetMergeBlock());
|
||||
if (!new_merge_block) return false;
|
||||
loop_->SetMergeBlock(new_merge_block);
|
||||
// Prevent the second loop from being executed if we already executed all the
|
||||
// required iterations.
|
||||
BasicBlock* if_block =
|
||||
ProtectLoop(loop_, has_remaining_iteration, if_merge_block);
|
||||
if (!if_block) return false;
|
||||
|
||||
// Patch the phi of the merge block.
|
||||
if_merge_block->ForEachPhiInst(
|
||||
[&clone_results, if_block, this](Instruction* phi) {
|
||||
@@ -471,14 +512,17 @@ void LoopPeeling::PeelBefore(uint32_t peel_factor) {
|
||||
context_->InvalidateAnalysesExceptFor(
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping |
|
||||
IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG);
|
||||
return true;
|
||||
}
|
||||
|
||||
void LoopPeeling::PeelAfter(uint32_t peel_factor) {
|
||||
bool LoopPeeling::PeelAfter(uint32_t peel_factor) {
|
||||
assert(CanPeelLoop() && "Cannot peel loop");
|
||||
LoopUtils::LoopCloningResult clone_results;
|
||||
|
||||
// Clone the loop and insert the cloned one before the loop.
|
||||
DuplicateAndConnectLoop(&clone_results);
|
||||
if (!DuplicateAndConnectLoop(&clone_results)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add a canonical induction variable "canonical_induction_variable_".
|
||||
InsertCanonicalInductionVariable(&clone_results);
|
||||
@@ -488,28 +532,33 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
|
||||
Instruction* factor =
|
||||
builder.GetIntConstant(peel_factor, int_type_->IsSigned());
|
||||
if (!factor) return false;
|
||||
|
||||
Instruction* has_remaining_iteration = builder.AddLessThan(
|
||||
factor->result_id(), loop_iteration_count_->result_id());
|
||||
if (!has_remaining_iteration) return false;
|
||||
|
||||
// Change the exit condition of the cloned loop to be (exit when become
|
||||
// false):
|
||||
// "canonical_induction_variable_" + "factor" < "loop_iteration_count_"
|
||||
FixExitCondition([factor, this](Instruction* insert_before_point) {
|
||||
InstructionBuilder cond_builder(
|
||||
context_, insert_before_point,
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
|
||||
// Build the following check: canonical_induction_variable_ + factor <
|
||||
// iteration_count
|
||||
return cond_builder
|
||||
.AddLessThan(cond_builder
|
||||
.AddIAdd(canonical_induction_variable_->type_id(),
|
||||
canonical_induction_variable_->result_id(),
|
||||
factor->result_id())
|
||||
->result_id(),
|
||||
loop_iteration_count_->result_id())
|
||||
->result_id();
|
||||
});
|
||||
if (!FixExitCondition([factor,
|
||||
this](Instruction* insert_before_point) -> uint32_t {
|
||||
InstructionBuilder cond_builder(
|
||||
context_, insert_before_point,
|
||||
IRContext::kAnalysisDefUse |
|
||||
IRContext::kAnalysisInstrToBlockMapping);
|
||||
// Build the following check: canonical_induction_variable_ + factor <
|
||||
// iteration_count
|
||||
Instruction* add = cond_builder.AddIAdd(
|
||||
canonical_induction_variable_->type_id(),
|
||||
canonical_induction_variable_->result_id(), factor->result_id());
|
||||
if (!add) return 0;
|
||||
Instruction* new_cond = cond_builder.AddLessThan(
|
||||
add->result_id(), loop_iteration_count_->result_id());
|
||||
return new_cond ? new_cond->result_id() : 0;
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// "Protect" the first loop: the first loop can only be executed if
|
||||
// factor < loop_iteration_count_.
|
||||
@@ -517,11 +566,17 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
|
||||
// The original loop's pre-header was the cloned loop merge block.
|
||||
GetClonedLoop()->SetMergeBlock(
|
||||
CreateBlockBefore(GetOriginalLoop()->GetPreHeaderBlock()));
|
||||
if (!GetClonedLoop()->GetMergeBlock()) {
|
||||
return false;
|
||||
}
|
||||
// Use the second loop preheader as if merge block.
|
||||
|
||||
// Prevent the first loop if only the peeled loop needs it.
|
||||
BasicBlock* if_block = ProtectLoop(cloned_loop_, has_remaining_iteration,
|
||||
GetOriginalLoop()->GetPreHeaderBlock());
|
||||
if (!if_block) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Patch the phi of the header block.
|
||||
// We added an if to enclose the first loop and because the phi node are
|
||||
@@ -529,8 +584,10 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
|
||||
// dominate the preheader.
|
||||
// We had to the preheader (our if merge block) the required phi instruction
|
||||
// and patch the header phi.
|
||||
bool ok = true;
|
||||
GetOriginalLoop()->GetHeaderBlock()->ForEachPhiInst(
|
||||
[&clone_results, if_block, this](Instruction* phi) {
|
||||
[&clone_results, if_block, &ok, this](Instruction* phi) {
|
||||
if (!ok) return;
|
||||
analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
|
||||
|
||||
auto find_value_idx = [](Instruction* phi_inst, Loop* loop) {
|
||||
@@ -554,15 +611,21 @@ void LoopPeeling::PeelAfter(uint32_t peel_factor) {
|
||||
find_value_idx(phi, GetOriginalLoop())),
|
||||
GetClonedLoop()->GetMergeBlock()->id(),
|
||||
cloned_preheader_value, if_block->id()});
|
||||
if (!new_phi) {
|
||||
ok = false;
|
||||
return;
|
||||
}
|
||||
|
||||
phi->SetInOperand(find_value_idx(phi, GetOriginalLoop()),
|
||||
{new_phi->result_id()});
|
||||
def_use_mgr->AnalyzeInstUse(phi);
|
||||
});
|
||||
if (!ok) return false;
|
||||
|
||||
context_->InvalidateAnalysesExceptFor(
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping |
|
||||
IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG);
|
||||
return true;
|
||||
}
|
||||
|
||||
Pass::Status LoopPeelingPass::Process() {
|
||||
@@ -571,13 +634,19 @@ Pass::Status LoopPeelingPass::Process() {
|
||||
|
||||
// Process each function in the module
|
||||
for (Function& f : *module) {
|
||||
modified |= ProcessFunction(&f);
|
||||
Pass::Status status = ProcessFunction(&f);
|
||||
if (status == Status::Failure) {
|
||||
return Status::Failure;
|
||||
}
|
||||
if (status == Status::SuccessWithChange) {
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
bool LoopPeelingPass::ProcessFunction(Function* f) {
|
||||
Pass::Status LoopPeelingPass::ProcessFunction(Function* f) {
|
||||
bool modified = false;
|
||||
LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);
|
||||
|
||||
@@ -593,41 +662,54 @@ bool LoopPeelingPass::ProcessFunction(Function* f) {
|
||||
CodeMetrics loop_size;
|
||||
loop_size.Analyze(*loop);
|
||||
|
||||
auto try_peel = [&loop_size, &modified, this](Loop* loop_to_peel) -> Loop* {
|
||||
auto try_peel = [&loop_size, &modified, this](
|
||||
Loop* loop_to_peel) -> std::pair<Pass::Status, Loop*> {
|
||||
if (!loop_to_peel->IsLCSSA()) {
|
||||
LoopUtils(context(), loop_to_peel).MakeLoopClosedSSA();
|
||||
}
|
||||
|
||||
bool peeled_loop;
|
||||
Pass::Status status;
|
||||
Loop* still_peelable_loop;
|
||||
std::tie(peeled_loop, still_peelable_loop) =
|
||||
std::tie(status, still_peelable_loop) =
|
||||
ProcessLoop(loop_to_peel, &loop_size);
|
||||
|
||||
if (peeled_loop) {
|
||||
if (status == Pass::Status::SuccessWithChange) {
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return still_peelable_loop;
|
||||
return {status, still_peelable_loop};
|
||||
};
|
||||
|
||||
Loop* still_peelable_loop = try_peel(loop);
|
||||
Pass::Status status;
|
||||
Loop* still_peelable_loop;
|
||||
std::tie(status, still_peelable_loop) = try_peel(loop);
|
||||
|
||||
if (status == Pass::Status::Failure) {
|
||||
return Pass::Status::Failure;
|
||||
}
|
||||
|
||||
// The pass is working out the maximum factor by which a loop can be peeled.
|
||||
// If the loop can potentially be peeled again, then there is only one
|
||||
// possible direction, so only one call is still needed.
|
||||
if (still_peelable_loop) {
|
||||
try_peel(loop);
|
||||
std::tie(status, still_peelable_loop) = try_peel(still_peelable_loop);
|
||||
if (status == Pass::Status::Failure) {
|
||||
return Pass::Status::Failure;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
return modified ? Pass::Status::SuccessWithChange
|
||||
: Pass::Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
|
||||
CodeMetrics* loop_size) {
|
||||
std::tuple<Pass::Status, Loop*> LoopPeelingPass::ProcessLoop(
|
||||
Loop* loop, CodeMetrics* loop_size) {
|
||||
ScalarEvolutionAnalysis* scev_analysis =
|
||||
context()->GetScalarEvolutionAnalysis();
|
||||
// Default values for bailing out.
|
||||
std::pair<bool, Loop*> bail_out{false, nullptr};
|
||||
std::tuple<Pass::Status, Loop*> bail_out{Pass::Status::SuccessWithoutChange,
|
||||
nullptr};
|
||||
|
||||
BasicBlock* exit_block = loop->FindConditionBlock();
|
||||
if (!exit_block) {
|
||||
@@ -744,7 +826,9 @@ std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
|
||||
Loop* extra_opportunity = nullptr;
|
||||
|
||||
if (direction == PeelDirection::kBefore) {
|
||||
peeler.PeelBefore(factor);
|
||||
if (!peeler.PeelBefore(factor)) {
|
||||
return {Pass::Status::Failure, nullptr};
|
||||
}
|
||||
if (stats_) {
|
||||
stats_->peeled_loops_.emplace_back(loop, PeelDirection::kBefore, factor);
|
||||
}
|
||||
@@ -753,7 +837,9 @@ std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
|
||||
extra_opportunity = peeler.GetOriginalLoop();
|
||||
}
|
||||
} else {
|
||||
peeler.PeelAfter(factor);
|
||||
if (!peeler.PeelAfter(factor)) {
|
||||
return {Pass::Status::Failure, nullptr};
|
||||
}
|
||||
if (stats_) {
|
||||
stats_->peeled_loops_.emplace_back(loop, PeelDirection::kAfter, factor);
|
||||
}
|
||||
@@ -763,7 +849,7 @@ std::pair<bool, Loop*> LoopPeelingPass::ProcessLoop(Loop* loop,
|
||||
}
|
||||
}
|
||||
|
||||
return {true, extra_opportunity};
|
||||
return {Pass::Status::SuccessWithChange, extra_opportunity};
|
||||
}
|
||||
|
||||
uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstLoopInvariantOperand(
|
||||
|
||||
26
3rdparty/spirv-tools/source/opt/loop_peeling.h
vendored
26
3rdparty/spirv-tools/source/opt/loop_peeling.h
vendored
@@ -148,11 +148,11 @@ class LoopPeeling {
|
||||
|
||||
// Moves the execution of the |factor| first iterations of the loop into a
|
||||
// dedicated loop.
|
||||
void PeelBefore(uint32_t factor);
|
||||
bool PeelBefore(uint32_t factor);
|
||||
|
||||
// Moves the execution of the |factor| last iterations of the loop into a
|
||||
// dedicated loop.
|
||||
void PeelAfter(uint32_t factor);
|
||||
bool PeelAfter(uint32_t factor);
|
||||
|
||||
// Returns the cloned loop.
|
||||
Loop* GetClonedLoop() { return cloned_loop_; }
|
||||
@@ -184,19 +184,19 @@ class LoopPeeling {
|
||||
// Duplicate |loop_| and place the new loop before the cloned loop. Iterating
|
||||
// values from the cloned loop are then connected to the original loop as
|
||||
// initializer.
|
||||
void DuplicateAndConnectLoop(LoopUtils::LoopCloningResult* clone_results);
|
||||
bool DuplicateAndConnectLoop(LoopUtils::LoopCloningResult* clone_results);
|
||||
|
||||
// Insert the canonical induction variable into the first loop as a simplified
|
||||
// counter.
|
||||
void InsertCanonicalInductionVariable(
|
||||
// counter. Returns true on success.
|
||||
bool InsertCanonicalInductionVariable(
|
||||
LoopUtils::LoopCloningResult* clone_results);
|
||||
|
||||
// Fixes the exit condition of the before loop. The function calls
|
||||
// |condition_builder| to get the condition to use in the conditional branch
|
||||
// of the loop exit. The loop will be exited if the condition evaluate to
|
||||
// true. |condition_builder| takes an Instruction* that represent the
|
||||
// insertion point.
|
||||
void FixExitCondition(
|
||||
// insertion point. Returns true on success.
|
||||
bool FixExitCondition(
|
||||
const std::function<uint32_t(Instruction*)>& condition_builder);
|
||||
|
||||
// Gathers all operations involved in the update of |iterator| into
|
||||
@@ -321,10 +321,14 @@ class LoopPeelingPass : public Pass {
|
||||
ScalarEvolutionAnalysis* scev_analysis_;
|
||||
size_t loop_max_iterations_;
|
||||
};
|
||||
// Peel profitable loops in |f|.
|
||||
bool ProcessFunction(Function* f);
|
||||
// Peel |loop| if profitable.
|
||||
std::pair<bool, Loop*> ProcessLoop(Loop* loop, CodeMetrics* loop_size);
|
||||
// Peel profitable loops in |f|. Returns Pass::Status::Failure if an error
|
||||
// occurs.
|
||||
Pass::Status ProcessFunction(Function* f);
|
||||
// Peel |loop| if profitable. Returns Pass::Status::Failure if an error
|
||||
// occurs. Returns {Pass::Status::SuccessWithChange, Loop*} if the loop is
|
||||
// peeled and there is another peeling opportunity.
|
||||
std::tuple<Pass::Status, Loop*> ProcessLoop(Loop* loop,
|
||||
CodeMetrics* loop_size);
|
||||
|
||||
static size_t code_grow_threshold_;
|
||||
LoopPeelingStats* stats_;
|
||||
|
||||
@@ -92,12 +92,16 @@ class LoopUnswitch {
|
||||
// position |ip|. This function preserves the def/use and instr to block
|
||||
// managers.
|
||||
BasicBlock* CreateBasicBlock(Function::iterator ip) {
|
||||
uint32_t new_label_id = TakeNextId();
|
||||
if (new_label_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
|
||||
|
||||
// TODO(1841): Handle id overflow.
|
||||
BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
|
||||
new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
|
||||
context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})))));
|
||||
context_, spv::Op::OpLabel, 0, new_label_id, {})))));
|
||||
bb->SetParent(function_);
|
||||
def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
|
||||
context_->set_instr_block(bb->GetLabelInst(), bb);
|
||||
@@ -135,7 +139,7 @@ class LoopUnswitch {
|
||||
}
|
||||
|
||||
// Unswitches |loop_|.
|
||||
void PerformUnswitch() {
|
||||
bool PerformUnswitch() {
|
||||
assert(CanUnswitchLoop() &&
|
||||
"Cannot unswitch if there is not constant condition");
|
||||
assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
|
||||
@@ -165,6 +169,9 @@ class LoopUnswitch {
|
||||
if_merge_block
|
||||
? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
|
||||
: nullptr;
|
||||
if (if_merge_block && !loop_merge_block) {
|
||||
return false;
|
||||
}
|
||||
if (loop_merge_block) {
|
||||
// Add the instruction and update managers.
|
||||
InstructionBuilder builder(
|
||||
@@ -174,17 +181,24 @@ class LoopUnswitch {
|
||||
builder.SetInsertPoint(&*loop_merge_block->begin());
|
||||
cfg.RegisterBlock(loop_merge_block);
|
||||
def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
|
||||
// Update CFG.
|
||||
bool ok = true;
|
||||
if_merge_block->ForEachPhiInst(
|
||||
[loop_merge_block, &builder, this](Instruction* phi) {
|
||||
[loop_merge_block, &ok, &builder, this](Instruction* phi) -> bool {
|
||||
Instruction* cloned = phi->Clone(context_);
|
||||
cloned->SetResultId(TakeNextId());
|
||||
uint32_t new_id = TakeNextId();
|
||||
if (new_id == 0) {
|
||||
ok = false;
|
||||
return false;
|
||||
}
|
||||
cloned->SetResultId(new_id);
|
||||
builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
|
||||
phi->SetInOperand(0, {cloned->result_id()});
|
||||
phi->SetInOperand(1, {loop_merge_block->id()});
|
||||
for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
|
||||
phi->RemoveInOperand(j);
|
||||
return true;
|
||||
});
|
||||
if (!ok) return false;
|
||||
// Copy the predecessor list (will get invalidated otherwise).
|
||||
std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
|
||||
for (uint32_t pid : preds) {
|
||||
@@ -227,6 +241,9 @@ class LoopUnswitch {
|
||||
// we need to create a dedicated block for the if.
|
||||
BasicBlock* loop_pre_header =
|
||||
CreateBasicBlock(++FindBasicBlockPosition(if_block));
|
||||
if (!loop_pre_header) {
|
||||
return false;
|
||||
}
|
||||
InstructionBuilder(
|
||||
context_, loop_pre_header,
|
||||
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
|
||||
@@ -308,6 +325,12 @@ class LoopUnswitch {
|
||||
// specific value.
|
||||
original_loop_constant_value =
|
||||
GetValueForDefaultPathForSwitch(iv_condition);
|
||||
if (!original_loop_constant_value) {
|
||||
return false;
|
||||
}
|
||||
if (!original_loop_constant_value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
|
||||
constant_branch.emplace_back(
|
||||
@@ -341,6 +364,9 @@ class LoopUnswitch {
|
||||
|
||||
Loop* cloned_loop =
|
||||
loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
|
||||
if (!cloned_loop) {
|
||||
return false;
|
||||
}
|
||||
specialisation_pair.second = cloned_loop->GetPreHeaderBlock();
|
||||
|
||||
////////////////////////////////////
|
||||
@@ -416,6 +442,7 @@ class LoopUnswitch {
|
||||
|
||||
context_->InvalidateAnalysesExceptFor(
|
||||
IRContext::Analysis::kAnalysisLoopAnalysis);
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -434,10 +461,7 @@ class LoopUnswitch {
|
||||
std::vector<BasicBlock*> ordered_loop_blocks_;
|
||||
|
||||
// Returns the next usable id for the context.
|
||||
uint32_t TakeNextId() {
|
||||
// TODO(1841): Handle id overflow.
|
||||
return context_->TakeNextId();
|
||||
}
|
||||
uint32_t TakeNextId() { return context_->TakeNextId(); }
|
||||
|
||||
// Simplifies |loop| assuming the instruction |to_version_insn| takes the
|
||||
// value |cst_value|. |block_range| is an iterator range returning the loop
|
||||
@@ -573,13 +597,15 @@ Pass::Status LoopUnswitchPass::Process() {
|
||||
|
||||
// Process each function in the module
|
||||
for (Function& f : *module) {
|
||||
modified |= ProcessFunction(&f);
|
||||
Pass::Status status = ProcessFunction(&f);
|
||||
if (status == Status::Failure) return Status::Failure;
|
||||
if (status == Status::SuccessWithChange) modified = true;
|
||||
}
|
||||
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
bool LoopUnswitchPass::ProcessFunction(Function* f) {
|
||||
Pass::Status LoopUnswitchPass::ProcessFunction(Function* f) {
|
||||
bool modified = false;
|
||||
std::unordered_set<Loop*> processed_loop;
|
||||
|
||||
@@ -599,15 +625,17 @@ bool LoopUnswitchPass::ProcessFunction(Function* f) {
|
||||
if (!loop.IsLCSSA()) {
|
||||
LoopUtils(context(), &loop).MakeLoopClosedSSA();
|
||||
}
|
||||
if (!unswitcher.PerformUnswitch()) {
|
||||
return Status::Failure;
|
||||
}
|
||||
modified = true;
|
||||
loop_changed = true;
|
||||
unswitcher.PerformUnswitch();
|
||||
}
|
||||
if (loop_changed) break;
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
|
||||
@@ -34,7 +34,8 @@ class LoopUnswitchPass : public Pass {
|
||||
Pass::Status Process() override;
|
||||
|
||||
private:
|
||||
bool ProcessFunction(Function* f);
|
||||
// Process the given function.
|
||||
Pass::Status ProcessFunction(Function* f);
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
||||
44
3rdparty/spirv-tools/source/opt/loop_utils.cpp
vendored
44
3rdparty/spirv-tools/source/opt/loop_utils.cpp
vendored
@@ -488,12 +488,18 @@ Loop* LoopUtils::CloneLoop(LoopCloningResult* cloning_result) const {
|
||||
|
||||
Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) {
|
||||
// Clone the loop.
|
||||
Loop* new_loop = CloneLoop(cloning_result);
|
||||
Loop* cloned_loop = CloneLoop(cloning_result);
|
||||
if (!cloned_loop) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create a new exit block/label for the new loop.
|
||||
// TODO(1841): Handle id overflow.
|
||||
std::unique_ptr<Instruction> new_label{new Instruction(
|
||||
context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})};
|
||||
uint32_t new_label_id = context_->TakeNextId();
|
||||
if (new_label_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<Instruction> new_label{
|
||||
new Instruction(context_, spv::Op::OpLabel, 0, new_label_id, {})};
|
||||
std::unique_ptr<BasicBlock> new_exit_bb{new BasicBlock(std::move(new_label))};
|
||||
new_exit_bb->SetParent(loop_->GetMergeBlock()->GetParent());
|
||||
|
||||
@@ -520,7 +526,7 @@ Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) {
|
||||
}
|
||||
|
||||
const uint32_t old_header = loop_->GetHeaderBlock()->id();
|
||||
const uint32_t new_header = new_loop->GetHeaderBlock()->id();
|
||||
const uint32_t new_header = cloned_loop->GetHeaderBlock()->id();
|
||||
analysis::DefUseManager* def_use = context_->get_def_use_mgr();
|
||||
|
||||
def_use->ForEachUse(old_header,
|
||||
@@ -529,22 +535,24 @@ Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) {
|
||||
inst->SetOperand(operand, {new_header});
|
||||
});
|
||||
|
||||
// TODO(1841): Handle failure to create pre-header.
|
||||
BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
|
||||
if (!pre_header) {
|
||||
return nullptr;
|
||||
}
|
||||
def_use->ForEachUse(
|
||||
loop_->GetOrCreatePreHeaderBlock()->id(),
|
||||
pre_header->id(),
|
||||
[new_merge_block, this](Instruction* inst, uint32_t operand) {
|
||||
if (this->loop_->IsInsideLoop(inst))
|
||||
inst->SetOperand(operand, {new_merge_block});
|
||||
|
||||
});
|
||||
new_loop->SetMergeBlock(new_exit_bb.get());
|
||||
cloned_loop->SetMergeBlock(new_exit_bb.get());
|
||||
|
||||
new_loop->SetPreHeaderBlock(loop_->GetPreHeaderBlock());
|
||||
cloned_loop->SetPreHeaderBlock(loop_->GetPreHeaderBlock());
|
||||
|
||||
// Add the new block into the cloned instructions.
|
||||
cloning_result->cloned_bb_.push_back(std::move(new_exit_bb));
|
||||
|
||||
return new_loop;
|
||||
return cloned_loop;
|
||||
}
|
||||
|
||||
Loop* LoopUtils::CloneLoop(
|
||||
@@ -562,8 +570,11 @@ Loop* LoopUtils::CloneLoop(
|
||||
// between old and new ids.
|
||||
BasicBlock* new_bb = old_bb->Clone(context_);
|
||||
new_bb->SetParent(&function_);
|
||||
// TODO(1841): Handle id overflow.
|
||||
new_bb->GetLabelInst()->SetResultId(context_->TakeNextId());
|
||||
uint32_t new_label_id = context_->TakeNextId();
|
||||
if (new_label_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
new_bb->GetLabelInst()->SetResultId(new_label_id);
|
||||
def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst());
|
||||
context_->set_instr_block(new_bb->GetLabelInst(), new_bb);
|
||||
cloning_result->cloned_bb_.emplace_back(new_bb);
|
||||
@@ -578,8 +589,11 @@ Loop* LoopUtils::CloneLoop(
|
||||
new_inst != new_bb->end(); ++new_inst, ++old_inst) {
|
||||
cloning_result->ptr_map_[&*new_inst] = &*old_inst;
|
||||
if (new_inst->HasResultId()) {
|
||||
// TODO(1841): Handle id overflow.
|
||||
new_inst->SetResultId(context_->TakeNextId());
|
||||
uint32_t new_result_id = context_->TakeNextId();
|
||||
if (new_result_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
new_inst->SetResultId(new_result_id);
|
||||
cloning_result->value_map_[old_inst->result_id()] =
|
||||
new_inst->result_id();
|
||||
|
||||
|
||||
2
3rdparty/spirv-tools/source/opt/loop_utils.h
vendored
2
3rdparty/spirv-tools/source/opt/loop_utils.h
vendored
@@ -114,6 +114,7 @@ class LoopUtils {
|
||||
// The function preserves the def/use, cfg and instr to block analyses.
|
||||
// The cloned loop nest will be added to the loop descriptor and will have
|
||||
// ownership.
|
||||
// Returns the cloned loop, or nullptr if the loop could not be cloned.
|
||||
Loop* CloneLoop(LoopCloningResult* cloning_result,
|
||||
const std::vector<BasicBlock*>& ordered_loop_blocks) const;
|
||||
// Clone |loop_| and remap its instructions, as above. Overload to compute
|
||||
@@ -121,6 +122,7 @@ class LoopUtils {
|
||||
Loop* CloneLoop(LoopCloningResult* cloning_result) const;
|
||||
|
||||
// Clone the |loop_| and make the new loop branch to the second loop on exit.
|
||||
// Returns the cloned loop, or nullptr if the loop could not be cloned.
|
||||
Loop* CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result);
|
||||
|
||||
// Perform a partial unroll of |loop| by given |factor|. This will copy the
|
||||
|
||||
@@ -58,7 +58,9 @@ Pass::Status MergeReturnPass::Process() {
|
||||
failed = true;
|
||||
}
|
||||
} else {
|
||||
MergeReturnBlocks(function, return_blocks);
|
||||
if (!MergeReturnBlocks(function, return_blocks)) {
|
||||
failed = true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@@ -171,10 +173,14 @@ bool MergeReturnPass::ProcessStructured(
|
||||
return true;
|
||||
}
|
||||
|
||||
void MergeReturnPass::CreateReturnBlock() {
|
||||
bool MergeReturnPass::CreateReturnBlock() {
|
||||
// Create a label for the new return block
|
||||
uint32_t label_id = TakeNextId();
|
||||
if (label_id == 0) {
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<Instruction> return_label(
|
||||
new Instruction(context(), spv::Op::OpLabel, 0u, TakeNextId(), {}));
|
||||
new Instruction(context(), spv::Op::OpLabel, 0u, label_id, {}));
|
||||
|
||||
// Create the new basic block
|
||||
std::unique_ptr<BasicBlock> return_block(
|
||||
@@ -186,14 +192,18 @@ void MergeReturnPass::CreateReturnBlock() {
|
||||
final_return_block_);
|
||||
assert(final_return_block_->GetParent() == function_ &&
|
||||
"The function should have been set when the block was created.");
|
||||
return true;
|
||||
}
|
||||
|
||||
void MergeReturnPass::CreateReturn(BasicBlock* block) {
|
||||
bool MergeReturnPass::CreateReturn(BasicBlock* block) {
|
||||
AddReturnValue();
|
||||
|
||||
if (return_value_) {
|
||||
// Load and return the final return value
|
||||
uint32_t loadId = TakeNextId();
|
||||
if (loadId == 0) {
|
||||
return false;
|
||||
}
|
||||
block->AddInstruction(MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpLoad, function_->type_id(), loadId,
|
||||
std::initializer_list<Operand>{
|
||||
@@ -216,6 +226,7 @@ void MergeReturnPass::CreateReturn(BasicBlock* block) {
|
||||
context()->AnalyzeDefUse(block->terminator());
|
||||
context()->set_instr_block(block->terminator(), block);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
|
||||
@@ -663,14 +674,16 @@ std::vector<BasicBlock*> MergeReturnPass::CollectReturnBlocks(
|
||||
return return_blocks;
|
||||
}
|
||||
|
||||
void MergeReturnPass::MergeReturnBlocks(
|
||||
bool MergeReturnPass::MergeReturnBlocks(
|
||||
Function* function, const std::vector<BasicBlock*>& return_blocks) {
|
||||
if (return_blocks.size() <= 1) {
|
||||
// No work to do.
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
CreateReturnBlock();
|
||||
if (!CreateReturnBlock()) {
|
||||
return false;
|
||||
}
|
||||
uint32_t return_id = final_return_block_->id();
|
||||
auto ret_block_iter = --function->end();
|
||||
// Create the PHI for the merged block (if necessary).
|
||||
@@ -687,6 +700,9 @@ void MergeReturnPass::MergeReturnBlocks(
|
||||
if (!phi_ops.empty()) {
|
||||
// Need a PHI node to select the correct return value.
|
||||
uint32_t phi_result_id = TakeNextId();
|
||||
if (phi_result_id == 0) {
|
||||
return false;
|
||||
}
|
||||
uint32_t phi_type_id = function->type_id();
|
||||
std::unique_ptr<Instruction> phi_inst(new Instruction(
|
||||
context(), spv::Op::OpPhi, phi_type_id, phi_result_id, phi_ops));
|
||||
@@ -718,6 +734,7 @@ void MergeReturnPass::MergeReturnBlocks(
|
||||
}
|
||||
|
||||
get_def_use_mgr()->AnalyzeInstDefUse(ret_block_iter->GetLabelInst());
|
||||
return true;
|
||||
}
|
||||
|
||||
void MergeReturnPass::AddNewPhiNodes() {
|
||||
@@ -781,8 +798,12 @@ void MergeReturnPass::InsertAfterElement(BasicBlock* element,
|
||||
}
|
||||
|
||||
bool MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
|
||||
CreateReturnBlock();
|
||||
CreateReturn(final_return_block_);
|
||||
if (!CreateReturnBlock()) {
|
||||
return false;
|
||||
}
|
||||
if (!CreateReturn(final_return_block_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) {
|
||||
cfg()->RegisterBlock(final_return_block_);
|
||||
@@ -828,7 +849,8 @@ BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) {
|
||||
|
||||
bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
|
||||
// Insert the switch before any code is run. We have to split the entry
|
||||
// block to make sure the OpVariable instructions remain in the entry block.
|
||||
// block to make sure the OpVariable instructions and DebugFunctionDefinition
|
||||
// instructions remain in the entry block.
|
||||
BasicBlock* start_block = &*function_->begin();
|
||||
auto split_pos = start_block->begin();
|
||||
while (split_pos->opcode() == spv::Op::OpVariable) {
|
||||
@@ -838,6 +860,18 @@ bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
|
||||
BasicBlock* old_block =
|
||||
start_block->SplitBasicBlock(context(), TakeNextId(), split_pos);
|
||||
|
||||
// Find DebugFunctionDefinition inst in the old block, and if we can find it,
|
||||
// move it to the entry block. Since DebugFunctionDefinition is not necessary
|
||||
// after OpVariable inst, we have to traverse the whole block to find it.
|
||||
for (auto pos = old_block->begin(); pos != old_block->end(); ++pos) {
|
||||
if (pos->GetShader100DebugOpcode() ==
|
||||
NonSemanticShaderDebugInfo100DebugFunctionDefinition) {
|
||||
start_block->AddInstruction(MakeUnique<Instruction>(*pos));
|
||||
pos.Erase();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the switch to the end of the entry block.
|
||||
InstructionBuilder builder(
|
||||
context(), start_block,
|
||||
|
||||
@@ -149,8 +149,9 @@ class MergeReturnPass : public MemPass {
|
||||
|
||||
// Creates a new basic block with a single return. If |function| returns a
|
||||
// value, a phi node is created to select the correct value to return.
|
||||
// Replaces old returns with an unconditional branch to the new block.
|
||||
void MergeReturnBlocks(Function* function,
|
||||
// Replaces old returns with an unconditional branch to the new block. Returns
|
||||
// true if successful.
|
||||
bool MergeReturnBlocks(Function* function,
|
||||
const std::vector<BasicBlock*>& returnBlocks);
|
||||
|
||||
// Generate and push new control flow state if |block| contains a merge.
|
||||
@@ -231,11 +232,12 @@ class MergeReturnPass : public MemPass {
|
||||
|
||||
// Add an |OpReturn| or |OpReturnValue| to the end of |block|. If an
|
||||
// |OpReturnValue| is needed, the return value is loaded from |return_value_|.
|
||||
void CreateReturn(BasicBlock* block);
|
||||
// Returns true if successful.
|
||||
bool CreateReturn(BasicBlock* block);
|
||||
|
||||
// Creates a block at the end of the function that will become the single
|
||||
// return block at the end of the pass.
|
||||
void CreateReturnBlock();
|
||||
bool CreateReturnBlock();
|
||||
|
||||
// Creates a Phi node in |merge_block| for the result of |inst|.
|
||||
// Any uses of the result of |inst| that are no longer
|
||||
@@ -332,4 +334,4 @@ class MergeReturnPass : public MemPass {
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_OPT_MERGE_RETURN_PASS_H_
|
||||
#endif // SOURCE_OPT_MERGE_RETURN_PASS_H_
|
||||
@@ -641,6 +641,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
|
||||
RegisterPass(CreateSplitCombinedImageSamplerPass());
|
||||
} else if (pass_name == "resolve-binding-conflicts") {
|
||||
RegisterPass(CreateResolveBindingConflictsPass());
|
||||
} else if (pass_name == "canonicalize-ids") {
|
||||
RegisterPass(CreateCanonicalizeIdsPass());
|
||||
} else {
|
||||
Errorf(consumer(), nullptr, {},
|
||||
"Unknown flag '--%s'. Use --help for a list of valid flags",
|
||||
@@ -1202,6 +1204,11 @@ Optimizer::PassToken CreateResolveBindingConflictsPass() {
|
||||
MakeUnique<opt::ResolveBindingConflictsPass>());
|
||||
}
|
||||
|
||||
Optimizer::PassToken CreateCanonicalizeIdsPass() {
|
||||
return MakeUnique<Optimizer::PassToken::Impl>(
|
||||
MakeUnique<opt::CanonicalizeIdsPass>());
|
||||
}
|
||||
|
||||
} // namespace spvtools
|
||||
|
||||
extern "C" {
|
||||
|
||||
1
3rdparty/spirv-tools/source/opt/passes.h
vendored
1
3rdparty/spirv-tools/source/opt/passes.h
vendored
@@ -21,6 +21,7 @@
|
||||
#include "source/opt/amd_ext_to_khr.h"
|
||||
#include "source/opt/analyze_live_input_pass.h"
|
||||
#include "source/opt/block_merge_pass.h"
|
||||
#include "source/opt/canonicalize_ids_pass.h"
|
||||
#include "source/opt/ccp_pass.h"
|
||||
#include "source/opt/cfg_cleanup_pass.h"
|
||||
#include "source/opt/code_sink.h"
|
||||
|
||||
@@ -29,6 +29,7 @@ namespace opt {
|
||||
|
||||
Pass::Status RemoveDuplicatesPass::Process() {
|
||||
bool modified = RemoveDuplicateCapabilities();
|
||||
modified |= RemoveDuplicateExtensions();
|
||||
modified |= RemoveDuplicatesExtInstImports();
|
||||
modified |= RemoveDuplicateTypes();
|
||||
modified |= RemoveDuplicateDecorations();
|
||||
@@ -36,6 +37,41 @@ Pass::Status RemoveDuplicatesPass::Process() {
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
bool RemoveDuplicatesPass::RemoveDuplicateExtensions() const {
|
||||
bool modified = false;
|
||||
|
||||
if (context()->extensions().empty()) {
|
||||
return modified;
|
||||
}
|
||||
|
||||
// set of {condition ID, extension name}
|
||||
// ID 0 means unconditional extension, ie., OpExtension, otherwise the ID is
|
||||
// the condition operand of OpConditionalExtensionINTEL.
|
||||
std::set<std::pair<uint32_t, std::string>> extensions;
|
||||
for (auto* inst = &*context()->extension_begin(); inst;) {
|
||||
uint32_t cond_id = 0;
|
||||
uint32_t i_name = 0;
|
||||
if (inst->opcode() == spv::Op::OpConditionalExtensionINTEL) {
|
||||
cond_id = inst->GetOperand(0).AsId();
|
||||
i_name = 1;
|
||||
}
|
||||
|
||||
auto res =
|
||||
extensions.insert({cond_id, inst->GetOperand(i_name).AsString()});
|
||||
|
||||
if (res.second) {
|
||||
// Never seen before, keep it.
|
||||
inst = inst->NextNode();
|
||||
} else {
|
||||
// It's a duplicate, remove it.
|
||||
inst = context()->KillInst(inst);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
}
|
||||
|
||||
bool RemoveDuplicatesPass::RemoveDuplicateCapabilities() const {
|
||||
bool modified = false;
|
||||
|
||||
@@ -43,16 +79,27 @@ bool RemoveDuplicatesPass::RemoveDuplicateCapabilities() const {
|
||||
return modified;
|
||||
}
|
||||
|
||||
std::unordered_set<uint32_t> capabilities;
|
||||
for (auto* i = &*context()->capability_begin(); i;) {
|
||||
auto res = capabilities.insert(i->GetSingleWordOperand(0u));
|
||||
// set of {condition ID, capability}
|
||||
// ID 0 means unconditional capability, ie., OpCapability, otherwise the ID is
|
||||
// the condition operand of OpConditionalCapabilityINTEL.
|
||||
std::set<std::pair<uint32_t, uint32_t>> capabilities;
|
||||
for (auto* inst = &*context()->capability_begin(); inst;) {
|
||||
uint32_t cond_id = 0;
|
||||
uint32_t i_cap = 0;
|
||||
if (inst->opcode() == spv::Op::OpConditionalCapabilityINTEL) {
|
||||
cond_id = inst->GetOperand(0).AsId();
|
||||
i_cap = 1;
|
||||
}
|
||||
|
||||
auto res =
|
||||
capabilities.insert({cond_id, inst->GetSingleWordOperand(i_cap)});
|
||||
|
||||
if (res.second) {
|
||||
// Never seen before, keep it.
|
||||
i = i->NextNode();
|
||||
inst = inst->NextNode();
|
||||
} else {
|
||||
// It's a duplicate, remove it.
|
||||
i = context()->KillInst(i);
|
||||
inst = context()->KillInst(inst);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,10 @@ class RemoveDuplicatesPass : public Pass {
|
||||
Status Process() override;
|
||||
|
||||
private:
|
||||
// Remove duplicate extensions from the module
|
||||
//
|
||||
// Returns true if the module was modified, false otherwise.
|
||||
bool RemoveDuplicateExtensions() const;
|
||||
// Remove duplicate capabilities from the module
|
||||
//
|
||||
// Returns true if the module was modified, false otherwise.
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "remove_unused_interface_variables_pass.h"
|
||||
|
||||
#include "source/spirv_constant.h"
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
@@ -55,7 +56,9 @@ class RemoveUnusedInterfaceVariablesContext {
|
||||
|
||||
void CollectUsedVariables() {
|
||||
std::queue<uint32_t> roots;
|
||||
roots.push(entry_.GetSingleWordInOperand(1));
|
||||
const int op_i =
|
||||
entry_.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 2 : 1;
|
||||
roots.push(entry_.GetSingleWordInOperand(op_i));
|
||||
parent_.context()->ProcessCallTreeFromRoots(pfn_, &roots);
|
||||
}
|
||||
|
||||
@@ -73,7 +76,9 @@ class RemoveUnusedInterfaceVariablesContext {
|
||||
}
|
||||
|
||||
void Modify() {
|
||||
for (int i = entry_.NumInOperands() - 1; i >= 3; --i)
|
||||
const int min_num_operands =
|
||||
entry_.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 4 : 3;
|
||||
for (int i = entry_.NumInOperands() - 1; i >= min_num_operands; --i)
|
||||
entry_.RemoveInOperand(i);
|
||||
for (auto id : operands_to_add_) {
|
||||
entry_.AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
|
||||
|
||||
@@ -186,7 +186,7 @@ bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
|
||||
Instruction* added_dbg_value =
|
||||
context()->get_debug_info_mgr()->AddDebugValueForDecl(
|
||||
dbg_decl, /*value_id=*/var->result_id(),
|
||||
/*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
|
||||
/*insert_before=*/insert_before, /*line=*/dbg_decl);
|
||||
|
||||
if (added_dbg_value == nullptr) return false;
|
||||
added_dbg_value->AddOperand(
|
||||
@@ -475,6 +475,7 @@ void ScalarReplacementPass::CreateVariable(
|
||||
|
||||
if (id == 0) {
|
||||
replacements->push_back(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<Instruction> variable(
|
||||
@@ -488,7 +489,10 @@ void ScalarReplacementPass::CreateVariable(
|
||||
Instruction* inst = &*block->begin();
|
||||
|
||||
// If varInst was initialized, make sure to initialize its replacement.
|
||||
GetOrCreateInitialValue(var_inst, index, inst);
|
||||
if (!GetOrCreateInitialValue(var_inst, index, inst)) {
|
||||
replacements->push_back(nullptr);
|
||||
return;
|
||||
}
|
||||
get_def_use_mgr()->AnalyzeInstDefUse(inst);
|
||||
context()->set_instr_block(inst, block);
|
||||
|
||||
@@ -509,11 +513,11 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
|
||||
return ptr_type_id;
|
||||
}
|
||||
|
||||
void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
||||
bool ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
||||
uint32_t index,
|
||||
Instruction* newVar) {
|
||||
assert(source->opcode() == spv::Op::OpVariable);
|
||||
if (source->NumInOperands() < 2) return;
|
||||
if (source->NumInOperands() < 2) return true;
|
||||
|
||||
uint32_t initId = source->GetSingleWordInOperand(1u);
|
||||
uint32_t storageId = GetStorageType(newVar)->result_id();
|
||||
@@ -525,6 +529,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
||||
auto iter = type_to_null_.find(storageId);
|
||||
if (iter == type_to_null_.end()) {
|
||||
newInitId = TakeNextId();
|
||||
if (newInitId == 0) return false;
|
||||
type_to_null_[storageId] = newInitId;
|
||||
context()->AddGlobalValue(
|
||||
MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId,
|
||||
@@ -537,6 +542,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
||||
} else if (IsSpecConstantInst(init->opcode())) {
|
||||
// Create a new constant extract.
|
||||
newInitId = TakeNextId();
|
||||
if (newInitId == 0) return false;
|
||||
context()->AddGlobalValue(MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpSpecConstantOp, storageId, newInitId,
|
||||
std::initializer_list<Operand>{
|
||||
@@ -561,6 +567,7 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
||||
if (newInitId != 0) {
|
||||
newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint64_t ScalarReplacementPass::GetArrayLength(
|
||||
|
||||
@@ -199,7 +199,9 @@ class ScalarReplacementPass : public MemPass {
|
||||
// If there is an initial value for |source| for element |index|, it is
|
||||
// appended as an operand on |newVar|. If the initial value is OpUndef, no
|
||||
// initial value is added to |newVar|.
|
||||
void GetOrCreateInitialValue(Instruction* source, uint32_t index,
|
||||
//
|
||||
// Returns true if the value was successfully created.
|
||||
bool GetOrCreateInitialValue(Instruction* source, uint32_t index,
|
||||
Instruction* newVar);
|
||||
|
||||
// Replaces the load to the entire composite.
|
||||
|
||||
@@ -556,10 +556,14 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
|
||||
Instruction* sampler;
|
||||
};
|
||||
std::vector<Replacement> replacements;
|
||||
bool error = false;
|
||||
|
||||
Function::RewriteParamFn rewriter =
|
||||
[&](std::unique_ptr<Instruction>&& param,
|
||||
std::back_insert_iterator<Function::ParamList>& appender) {
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
if (combined_types_.count(param->type_id()) == 0) {
|
||||
appender = std::move(param);
|
||||
return;
|
||||
@@ -569,12 +573,22 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
|
||||
auto* combined_inst = param.release();
|
||||
auto* combined_type = def_use_mgr_->GetDef(combined_inst->type_id());
|
||||
auto [image_type, sampler_type] = SplitType(*combined_type);
|
||||
uint32_t image_param_id = context()->TakeNextId();
|
||||
if (image_param_id == 0) {
|
||||
error = true;
|
||||
return;
|
||||
}
|
||||
auto image_param = MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpFunctionParameter, image_type->result_id(),
|
||||
context()->TakeNextId(), Instruction::OperandList{});
|
||||
image_param_id, Instruction::OperandList{});
|
||||
uint32_t sampler_param_id = context()->TakeNextId();
|
||||
if (sampler_param_id == 0) {
|
||||
error = true;
|
||||
return;
|
||||
}
|
||||
auto sampler_param = MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpFunctionParameter,
|
||||
sampler_type->result_id(), context()->TakeNextId(),
|
||||
sampler_type->result_id(), sampler_param_id,
|
||||
Instruction::OperandList{});
|
||||
replacements.push_back(
|
||||
{combined_inst, image_param.get(), sampler_param.get()});
|
||||
@@ -583,6 +597,10 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
|
||||
};
|
||||
fn.RewriteParams(rewriter);
|
||||
|
||||
if (error) {
|
||||
return SPV_ERROR_INTERNAL;
|
||||
}
|
||||
|
||||
for (auto& r : replacements) {
|
||||
modified_ = true;
|
||||
def_use_mgr_->AnalyzeInstDefUse(r.image);
|
||||
|
||||
@@ -87,13 +87,15 @@ std::string SSARewriter::PhiCandidate::PrettyPrint(const CFG* cfg) const {
|
||||
return str.str();
|
||||
}
|
||||
|
||||
SSARewriter::PhiCandidate& SSARewriter::CreatePhiCandidate(uint32_t var_id,
|
||||
SSARewriter::PhiCandidate* SSARewriter::CreatePhiCandidate(uint32_t var_id,
|
||||
BasicBlock* bb) {
|
||||
// TODO(1841): Handle id overflow.
|
||||
uint32_t phi_result_id = pass_->context()->TakeNextId();
|
||||
if (phi_result_id == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto result = phi_candidates_.emplace(
|
||||
phi_result_id, PhiCandidate(var_id, phi_result_id, bb));
|
||||
PhiCandidate& phi_candidate = result.first->second;
|
||||
PhiCandidate* phi_candidate = &result.first->second;
|
||||
return phi_candidate;
|
||||
}
|
||||
|
||||
@@ -268,11 +270,12 @@ uint32_t SSARewriter::GetReachingDef(uint32_t var_id, BasicBlock* bb) {
|
||||
// If there is more than one predecessor, this is a join block which may
|
||||
// require a Phi instruction. This will act as |var_id|'s current
|
||||
// definition to break potential cycles.
|
||||
PhiCandidate& phi_candidate = CreatePhiCandidate(var_id, bb);
|
||||
PhiCandidate* phi_candidate = CreatePhiCandidate(var_id, bb);
|
||||
if (phi_candidate == nullptr) return 0;
|
||||
|
||||
// Set the value for |bb| to avoid an infinite recursion.
|
||||
WriteVariable(var_id, bb, phi_candidate.result_id());
|
||||
val_id = AddPhiOperands(&phi_candidate);
|
||||
WriteVariable(var_id, bb, phi_candidate->result_id());
|
||||
val_id = AddPhiOperands(phi_candidate);
|
||||
}
|
||||
|
||||
// If we could not find a store for this variable in the path from the root
|
||||
|
||||
@@ -232,7 +232,7 @@ class SSARewriter {
|
||||
// during rewriting.
|
||||
//
|
||||
// Once the candidate Phi is created, it returns its ID.
|
||||
PhiCandidate& CreatePhiCandidate(uint32_t var_id, BasicBlock* bb);
|
||||
PhiCandidate* CreatePhiCandidate(uint32_t var_id, BasicBlock* bb);
|
||||
|
||||
// Attempts to remove a trivial Phi candidate |phi_cand|. Trivial Phis are
|
||||
// those that only reference themselves and one other value |val| any number
|
||||
|
||||
@@ -53,17 +53,15 @@ bool IsPowerOf2(uint32_t val) {
|
||||
|
||||
Pass::Status StrengthReductionPass::Process() {
|
||||
// Initialize the member variables on a per module basis.
|
||||
bool modified = false;
|
||||
int32_type_id_ = 0;
|
||||
uint32_type_id_ = 0;
|
||||
std::memset(constant_ids_, 0, sizeof(constant_ids_));
|
||||
|
||||
FindIntTypesAndConstants();
|
||||
modified = ScanFunctions();
|
||||
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
|
||||
return ScanFunctions();
|
||||
}
|
||||
|
||||
bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
|
||||
Pass::Status StrengthReductionPass::ReplaceMultiplyByPowerOf2(
|
||||
BasicBlock::iterator* inst) {
|
||||
assert((*inst)->opcode() == spv::Op::OpIMul &&
|
||||
"Only works for multiplication of integers.");
|
||||
@@ -72,7 +70,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
|
||||
// Currently only works on 32-bit integers.
|
||||
if ((*inst)->type_id() != int32_type_id_ &&
|
||||
(*inst)->type_id() != uint32_type_id_) {
|
||||
return modified;
|
||||
return Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
// Check the operands for a constant that is a power of 2.
|
||||
@@ -87,9 +85,11 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
|
||||
modified = true;
|
||||
uint32_t shiftAmount = CountTrailingZeros(constVal);
|
||||
uint32_t shiftConstResultId = GetConstantId(shiftAmount);
|
||||
if (shiftConstResultId == 0) return Status::Failure;
|
||||
|
||||
// Create the new instruction.
|
||||
uint32_t newResultId = TakeNextId();
|
||||
if (newResultId == 0) return Status::Failure;
|
||||
std::vector<Operand> newOperands;
|
||||
newOperands.push_back((*inst)->GetInOperand(1 - i));
|
||||
Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
|
||||
@@ -117,7 +117,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
void StrengthReductionPass::FindIntTypesAndConstants() {
|
||||
@@ -152,6 +152,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
|
||||
|
||||
// Construct the constant.
|
||||
uint32_t resultId = TakeNextId();
|
||||
if (resultId == 0) return 0;
|
||||
Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
||||
{val});
|
||||
std::unique_ptr<Instruction> newConstant(new Instruction(
|
||||
@@ -169,7 +170,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
|
||||
return constant_ids_[val];
|
||||
}
|
||||
|
||||
bool StrengthReductionPass::ScanFunctions() {
|
||||
Pass::Status StrengthReductionPass::ScanFunctions() {
|
||||
// I did not use |ForEachInst| in the module because the function that acts on
|
||||
// the instruction gets a pointer to the instruction. We cannot use that to
|
||||
// insert a new instruction. I want an iterator.
|
||||
@@ -178,16 +179,19 @@ bool StrengthReductionPass::ScanFunctions() {
|
||||
for (auto& bb : func) {
|
||||
for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
|
||||
switch (inst->opcode()) {
|
||||
case spv::Op::OpIMul:
|
||||
if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
|
||||
case spv::Op::OpIMul: {
|
||||
Status s = ReplaceMultiplyByPowerOf2(&inst);
|
||||
if (s == Status::Failure) return Status::Failure;
|
||||
if (s == Status::SuccessWithChange) modified = true;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
|
||||
@@ -32,7 +32,7 @@ class StrengthReductionPass : public Pass {
|
||||
private:
|
||||
// Replaces multiplication by a power of 2 with an equivalent bit shift.
|
||||
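// Returns SuccessWithChange if something changed, SuccessWithoutChange if
// nothing did, and Failure if a required new id could not be obtained.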
|
||||
bool ReplaceMultiplyByPowerOf2(BasicBlock::iterator*);
|
||||
Status ReplaceMultiplyByPowerOf2(BasicBlock::iterator*);
|
||||
|
||||
// Scan the types and constants in the module looking for the integer
|
||||
// types that we are
|
||||
@@ -47,7 +47,7 @@ class StrengthReductionPass : public Pass {
|
||||
|
||||
// Replaces certain instructions in function bodies with presumably cheaper
|
||||
// ones. Returns SuccessWithChange if something changed,
// SuccessWithoutChange if nothing did, and Failure if a replacement failed.
|
||||
bool ScanFunctions();
|
||||
Status ScanFunctions();
|
||||
|
||||
// Type ids for the types of interest, or 0 if they do not exist.
|
||||
uint32_t int32_type_id_;
|
||||
|
||||
@@ -427,20 +427,20 @@ Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
|
||||
// Opcode of interest to determine capabilities requirements.
|
||||
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 14> kOpcodeHandlers{{
|
||||
// clang-format off
|
||||
{spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
|
||||
{spv::Op::OpImageWrite, Handler_OpImageWrite_StorageImageWriteWithoutFormat},
|
||||
{spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
|
||||
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
|
||||
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
|
||||
{spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
|
||||
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int16 },
|
||||
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniformBufferBlock16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageBuffer16BitAccess},
|
||||
{spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
|
||||
{spv::Op::OpImageWrite, Handler_OpImageWrite_StorageImageWriteWithoutFormat},
|
||||
{spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
|
||||
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
|
||||
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
|
||||
{spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
|
||||
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int16 },
|
||||
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniformBufferBlock16},
|
||||
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageBuffer16BitAccess},
|
||||
// clang-format on
|
||||
}};
|
||||
|
||||
@@ -612,7 +612,9 @@ void TrimCapabilitiesPass::addInstructionRequirements(
|
||||
ExtensionSet* extensions) const {
|
||||
// Ignoring OpCapability and OpExtension instructions.
|
||||
if (instruction->opcode() == spv::Op::OpCapability ||
|
||||
instruction->opcode() == spv::Op::OpExtension) {
|
||||
instruction->opcode() == spv::Op::OpConditionalCapabilityINTEL ||
|
||||
instruction->opcode() == spv::Op::OpExtension ||
|
||||
instruction->opcode() == spv::Op::OpConditionalExtensionINTEL) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -631,7 +633,7 @@ void TrimCapabilitiesPass::addInstructionRequirements(
|
||||
}
|
||||
|
||||
// Last case: some complex logic needs to be run to determine capabilities.
|
||||
auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
|
||||
auto [begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
|
||||
for (auto it = begin; it != end; it++) {
|
||||
const OpcodeHandler handler = it->second;
|
||||
auto result = handler(instruction);
|
||||
@@ -754,7 +756,7 @@ Pass::Status TrimCapabilitiesPass::Process() {
|
||||
return Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
auto[required_capabilities, required_extensions] =
|
||||
auto [required_capabilities, required_extensions] =
|
||||
DetermineRequiredCapabilitiesAndExtensions();
|
||||
|
||||
Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
|
||||
|
||||
@@ -82,6 +82,7 @@ class TrimCapabilitiesPass : public Pass {
|
||||
spv::Capability::FragmentShaderPixelInterlockEXT,
|
||||
spv::Capability::FragmentShaderSampleInterlockEXT,
|
||||
spv::Capability::FragmentShaderShadingRateInterlockEXT,
|
||||
spv::Capability::Geometry,
|
||||
spv::Capability::GroupNonUniform,
|
||||
spv::Capability::GroupNonUniformArithmetic,
|
||||
spv::Capability::GroupNonUniformClustered,
|
||||
|
||||
91
3rdparty/spirv-tools/source/opt/type_manager.cpp
vendored
91
3rdparty/spirv-tools/source/opt/type_manager.cpp
vendored
@@ -495,6 +495,49 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
|
||||
{SPV_OPERAND_TYPE_ID, {coop_vec->components()}}});
|
||||
break;
|
||||
}
|
||||
case Type::kTensorARM: {
|
||||
auto tensor_type = type->AsTensorARM();
|
||||
uint32_t const element_type =
|
||||
GetTypeInstruction(tensor_type->element_type());
|
||||
if (element_type == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (tensor_type->rank_id() != 0) {
|
||||
if (tensor_type->shape_id() != 0) {
|
||||
typeInst = MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpTypeTensorARM, 0, id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {element_type}},
|
||||
{SPV_OPERAND_TYPE_ID, {tensor_type->rank_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {tensor_type->shape_id()}}});
|
||||
} else {
|
||||
typeInst = MakeUnique<Instruction>(
|
||||
context(), spv::Op::OpTypeTensorARM, 0, id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {element_type}},
|
||||
{SPV_OPERAND_TYPE_ID, {tensor_type->rank_id()}}});
|
||||
}
|
||||
} else {
|
||||
typeInst =
|
||||
MakeUnique<Instruction>(context(), spv::Op::OpTypeTensorARM, 0, id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {element_type}}});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Type::kGraphARM: {
|
||||
auto const gty = type->AsGraphARM();
|
||||
std::vector<Operand> ops;
|
||||
ops.push_back(
|
||||
Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {gty->num_inputs()}));
|
||||
for (auto iotype : gty->io_types()) {
|
||||
uint32_t iotype_id = GetTypeInstruction(iotype);
|
||||
ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {iotype_id}));
|
||||
}
|
||||
typeInst = MakeUnique<Instruction>(context(), spv::Op::OpTypeGraphARM, 0,
|
||||
id, ops);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(false && "Unexpected type");
|
||||
break;
|
||||
@@ -754,6 +797,23 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
|
||||
cv_type->components());
|
||||
break;
|
||||
}
|
||||
case Type::kTensorARM: {
|
||||
const TensorARM* tensor_type = type.AsTensorARM();
|
||||
const Type* element_type = tensor_type->element_type();
|
||||
rebuilt_ty = MakeUnique<TensorARM>(
|
||||
RebuildType(GetId(element_type), *element_type),
|
||||
tensor_type->rank_id(), tensor_type->shape_id());
|
||||
break;
|
||||
}
|
||||
case Type::kGraphARM: {
|
||||
const GraphARM* graph_type = type.AsGraphARM();
|
||||
std::vector<const Type*> io_types;
|
||||
for (auto ioty : graph_type->io_types()) {
|
||||
io_types.push_back(RebuildType(GetId(ioty), *ioty));
|
||||
}
|
||||
rebuilt_ty = MakeUnique<GraphARM>(graph_type->num_inputs(), io_types);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
return nullptr;
|
||||
@@ -1036,6 +1096,31 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
|
||||
inst.GetSingleWordInOperand(1), perm);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeTensorARM: {
|
||||
switch (inst.NumInOperands()) {
|
||||
case 1:
|
||||
type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)));
|
||||
break;
|
||||
case 2:
|
||||
type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)),
|
||||
inst.GetSingleWordInOperand(1));
|
||||
break;
|
||||
case 3:
|
||||
type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)),
|
||||
inst.GetSingleWordInOperand(1),
|
||||
inst.GetSingleWordInOperand(2));
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpTypeGraphARM: {
|
||||
std::vector<const Type*> io_types;
|
||||
for (unsigned i = 1; i < inst.NumInOperands(); i++) {
|
||||
io_types.push_back(GetType(inst.GetSingleWordInOperand(i)));
|
||||
}
|
||||
type = new GraphARM(inst.GetSingleWordInOperand(0), io_types);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(false && "Type not handled by the type manager.");
|
||||
break;
|
||||
@@ -1067,7 +1152,11 @@ void TypeManager::AttachDecoration(const Instruction& inst, Type* type) {
|
||||
const auto count = inst.NumOperands();
|
||||
std::vector<uint32_t> data;
|
||||
for (uint32_t i = 1; i < count; ++i) {
|
||||
data.push_back(inst.GetSingleWordOperand(i));
|
||||
// LinkageAttributes has a literal string as an operand, which is a
|
||||
// varible length word. We cannot assume that all operands are single
|
||||
// word.
|
||||
const Operand::OperandData& words = inst.GetOperand(i).words;
|
||||
data.insert(data.end(), words.begin(), words.end());
|
||||
}
|
||||
type->AddDecoration(std::move(data));
|
||||
} break;
|
||||
|
||||
91
3rdparty/spirv-tools/source/opt/types.cpp
vendored
91
3rdparty/spirv-tools/source/opt/types.cpp
vendored
@@ -135,6 +135,8 @@ std::unique_ptr<Type> Type::Clone() const {
|
||||
DeclareKindCase(CooperativeVectorNV);
|
||||
DeclareKindCase(RayQueryKHR);
|
||||
DeclareKindCase(HitObjectNV);
|
||||
DeclareKindCase(TensorARM);
|
||||
DeclareKindCase(GraphARM);
|
||||
#undef DeclareKindCase
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
@@ -187,6 +189,8 @@ bool Type::operator==(const Type& other) const {
|
||||
DeclareKindCase(HitObjectNV);
|
||||
DeclareKindCase(TensorLayoutNV);
|
||||
DeclareKindCase(TensorViewNV);
|
||||
DeclareKindCase(TensorARM);
|
||||
DeclareKindCase(GraphARM);
|
||||
#undef DeclareKindCase
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
@@ -247,6 +251,8 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
|
||||
DeclareKindCase(HitObjectNV);
|
||||
DeclareKindCase(TensorLayoutNV);
|
||||
DeclareKindCase(TensorViewNV);
|
||||
DeclareKindCase(TensorARM);
|
||||
DeclareKindCase(GraphARM);
|
||||
#undef DeclareKindCase
|
||||
default:
|
||||
assert(false && "Unhandled type");
|
||||
@@ -899,6 +905,91 @@ bool CooperativeVectorNV::IsSameImpl(const Type* that,
|
||||
components_ == mt->components_ && HasSameDecorations(that);
|
||||
}
|
||||
|
||||
// Creates a tensor type over |elty| with the given rank and shape ids.
// In assert-enabled builds, |elty| must be non-null and a non-zero shape id
// must be accompanied by a non-zero rank id.
TensorARM::TensorARM(const Type* elty, const uint32_t rank,
                     const uint32_t shape)
    : Type(kTensorARM), element_type_(elty), rank_id_(rank), shape_id_(shape) {
  assert(elty != nullptr);
  // Nothing more to validate when no shape id was supplied.
  if (shape == 0) return;
  assert(rank != 0);
}
|
||||
|
||||
// Renders the type as "tensor<ELEMENT, id(RANK_ID), id(SHAPE_ID)>".
std::string TensorARM::str() const {
  std::ostringstream text;
  text << "tensor<" << element_type_->str();
  text << ", id(" << rank_id_ << ")";
  text << ", id(" << shape_id_ << ")>";
  return text.str();
}
|
||||
|
||||
size_t TensorARM::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
|
||||
hash = hash_combine(hash, rank_id_);
|
||||
hash = hash_combine(hash, shape_id_);
|
||||
return element_type_->ComputeHashValue(hash, seen);
|
||||
}
|
||||
|
||||
bool TensorARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
|
||||
const TensorARM* tt = that->AsTensorARM();
|
||||
if (!tt) return false;
|
||||
return element_type_->IsSameImpl(tt->element_type_, seen) &&
|
||||
rank_id_ == tt->rank_id_ && shape_id_ == tt->shape_id_ &&
|
||||
HasSameDecorations(that);
|
||||
}
|
||||
|
||||
// Creates a graph type from its input count and its list of interface types.
// In assert-enabled builds the interface list must not be empty.
GraphARM::GraphARM(const uint32_t num_inputs,
                   const std::vector<const Type*>& io_types)
    : Type(kGraphARM), num_inputs_(num_inputs), io_types_(io_types) {
  assert(io_types.size() > 0);
}
|
||||
|
||||
std::string GraphARM::str() const {
|
||||
std::ostringstream oss;
|
||||
oss << "graph<" << num_inputs_;
|
||||
for (auto ioty : io_types_) {
|
||||
oss << "," << ioty->str();
|
||||
}
|
||||
oss << ">";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
bool GraphARM::is_shaped() const {
|
||||
// A graph is considered to be shaped if all its interface tensors are shaped
|
||||
for (auto ioty : io_types_) {
|
||||
auto tensor_type = ioty->AsTensorARM();
|
||||
assert(tensor_type);
|
||||
if (!tensor_type->is_shaped()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t GraphARM::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
|
||||
hash = hash_combine(hash, num_inputs_);
|
||||
for (auto ioty : io_types_) {
|
||||
hash = ioty->ComputeHashValue(hash, seen);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
bool GraphARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
|
||||
const GraphARM* og = that->AsGraphARM();
|
||||
if (!og) {
|
||||
return false;
|
||||
}
|
||||
if (num_inputs_ != og->num_inputs_) {
|
||||
return false;
|
||||
}
|
||||
if (io_types_.size() != og->io_types_.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < io_types_.size(); i++) {
|
||||
if (!io_types_[i]->IsSameImpl(og->io_types_[i], seen)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
||||
56
3rdparty/spirv-tools/source/opt/types.h
vendored
56
3rdparty/spirv-tools/source/opt/types.h
vendored
@@ -69,6 +69,8 @@ class RayQueryKHR;
|
||||
class HitObjectNV;
|
||||
class TensorLayoutNV;
|
||||
class TensorViewNV;
|
||||
class TensorARM;
|
||||
class GraphARM;
|
||||
|
||||
// Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
|
||||
// which is used as a way to probe the actual <subclass>.
|
||||
@@ -114,6 +116,8 @@ class Type {
|
||||
kHitObjectNV,
|
||||
kTensorLayoutNV,
|
||||
kTensorViewNV,
|
||||
kTensorARM,
|
||||
kGraphARM,
|
||||
kLast
|
||||
};
|
||||
|
||||
@@ -220,6 +224,8 @@ class Type {
|
||||
DeclareCastMethod(HitObjectNV)
|
||||
DeclareCastMethod(TensorLayoutNV)
|
||||
DeclareCastMethod(TensorViewNV)
|
||||
DeclareCastMethod(TensorARM)
|
||||
DeclareCastMethod(GraphARM)
|
||||
#undef DeclareCastMethod
|
||||
|
||||
protected:
|
||||
@@ -774,6 +780,56 @@ class CooperativeVectorNV : public Type {
|
||||
const uint32_t components_;
|
||||
};
|
||||
|
||||
class TensorARM : public Type {
|
||||
public:
|
||||
TensorARM(const Type* elty, const uint32_t rank = 0,
|
||||
const uint32_t shape = 0);
|
||||
TensorARM(const TensorARM&) = default;
|
||||
|
||||
std::string str() const override;
|
||||
|
||||
TensorARM* AsTensorARM() override { return this; }
|
||||
const TensorARM* AsTensorARM() const override { return this; }
|
||||
|
||||
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
|
||||
|
||||
const Type* element_type() const { return element_type_; }
|
||||
uint32_t rank_id() const { return rank_id_; }
|
||||
uint32_t shape_id() const { return shape_id_; }
|
||||
bool is_ranked() const { return rank_id_ != 0; }
|
||||
bool is_shaped() const { return shape_id_ != 0; }
|
||||
|
||||
private:
|
||||
bool IsSameImpl(const Type* that, IsSameCache*) const override;
|
||||
|
||||
const Type* element_type_;
|
||||
const uint32_t rank_id_;
|
||||
const uint32_t shape_id_;
|
||||
};
|
||||
|
||||
class GraphARM : public Type {
|
||||
public:
|
||||
GraphARM(const uint32_t num_inputs, const std::vector<const Type*>& io_types);
|
||||
GraphARM(const GraphARM&) = default;
|
||||
|
||||
std::string str() const override;
|
||||
|
||||
GraphARM* AsGraphARM() override { return this; }
|
||||
const GraphARM* AsGraphARM() const override { return this; }
|
||||
|
||||
uint32_t num_inputs() const { return num_inputs_; }
|
||||
const std::vector<const Type*>& io_types() const { return io_types_; }
|
||||
bool is_shaped() const;
|
||||
|
||||
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
|
||||
|
||||
private:
|
||||
bool IsSameImpl(const Type* that, IsSameCache*) const override;
|
||||
|
||||
const uint32_t num_inputs_;
|
||||
const std::vector<const Type*> io_types_;
|
||||
};
|
||||
|
||||
#define DefineParameterlessType(type, name) \
|
||||
class type : public Type { \
|
||||
public: \
|
||||
|
||||
@@ -160,14 +160,38 @@ void UpgradeMemoryModel::UpgradeMemoryAndImages() {
|
||||
}
|
||||
|
||||
switch (inst->opcode()) {
|
||||
case spv::Op::OpLoad:
|
||||
case spv::Op::OpLoad: {
|
||||
Instruction* src_pointer = context()->get_def_use_mgr()->GetDef(
|
||||
inst->GetSingleWordInOperand(0u));
|
||||
analysis::Type* src_type =
|
||||
context()->get_type_mgr()->GetType(src_pointer->type_id());
|
||||
auto storage_class = src_type->AsPointer()->storage_class();
|
||||
if (storage_class == spv::StorageClass::Function ||
|
||||
storage_class == spv::StorageClass::Private) {
|
||||
// If the memory is a Function or Private storage class variable, the
|
||||
// NonPrivatePointer flag is unnecessary.
|
||||
is_coherent = false;
|
||||
}
|
||||
UpgradeFlags(inst, 1u, is_coherent, is_volatile, kVisibility,
|
||||
kMemory);
|
||||
break;
|
||||
case spv::Op::OpStore:
|
||||
}
|
||||
case spv::Op::OpStore: {
|
||||
Instruction* src_pointer = context()->get_def_use_mgr()->GetDef(
|
||||
inst->GetSingleWordInOperand(0u));
|
||||
analysis::Type* src_type =
|
||||
context()->get_type_mgr()->GetType(src_pointer->type_id());
|
||||
auto storage_class = src_type->AsPointer()->storage_class();
|
||||
if (storage_class == spv::StorageClass::Function ||
|
||||
storage_class == spv::StorageClass::Private) {
|
||||
// If the memory is a Function or Private storage class variable, the
|
||||
// NonPrivatePointer flag is unnecessary.
|
||||
is_coherent = false;
|
||||
}
|
||||
UpgradeFlags(inst, 2u, is_coherent, is_volatile, kAvailability,
|
||||
kMemory);
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpCopyMemory:
|
||||
case spv::Op::OpCopyMemorySized:
|
||||
start_operand = inst->opcode() == spv::Op::OpCopyMemory ? 2u : 3u;
|
||||
@@ -366,6 +390,21 @@ std::pair<bool, bool> UpgradeMemoryModel::TraceInstruction(
|
||||
indices.push_back(inst->GetSingleWordInOperand(i));
|
||||
}
|
||||
break;
|
||||
case spv::Op::OpLoad:
|
||||
if (context()->get_type_mgr()->GetType(inst->type_id())->AsPointer()) {
|
||||
analysis::Integer int_ty(32, false);
|
||||
uint32_t int_id =
|
||||
context()->get_type_mgr()->GetTypeInstruction(&int_ty);
|
||||
const analysis::Constant* constant =
|
||||
context()->get_constant_mgr()->GetConstant(
|
||||
context()->get_type_mgr()->GetType(int_id), {0u});
|
||||
uint32_t constant_id = context()
|
||||
->get_constant_mgr()
|
||||
->GetDefiningInstruction(constant)
|
||||
->result_id();
|
||||
|
||||
indices.push_back(constant_id);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -661,22 +700,29 @@ void UpgradeMemoryModel::UpgradeBarriers() {
|
||||
roots.push(e.GetSingleWordInOperand(1u));
|
||||
if (context()->ProcessCallTreeFromRoots(CollectBarriers, &roots)) {
|
||||
for (auto barrier : barriers) {
|
||||
// Add OutputMemoryKHR to the semantics of the barriers.
|
||||
// Add OutputMemoryKHR to the semantics of the non-relaxed barriers.
|
||||
uint32_t semantics_id = barrier->GetSingleWordInOperand(2u);
|
||||
Instruction* semantics_inst =
|
||||
context()->get_def_use_mgr()->GetDef(semantics_id);
|
||||
analysis::Type* semantics_type =
|
||||
context()->get_type_mgr()->GetType(semantics_inst->type_id());
|
||||
uint64_t semantics_value = GetIndexValue(semantics_inst);
|
||||
const analysis::Constant* constant =
|
||||
context()->get_constant_mgr()->GetConstant(
|
||||
semantics_type,
|
||||
{static_cast<uint32_t>(semantics_value) |
|
||||
uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR)});
|
||||
barrier->SetInOperand(2u, {context()
|
||||
->get_constant_mgr()
|
||||
->GetDefiningInstruction(constant)
|
||||
->result_id()});
|
||||
const uint64_t memory_order_mask =
|
||||
uint64_t(spv::MemorySemanticsMask::Acquire |
|
||||
spv::MemorySemanticsMask::Release |
|
||||
spv::MemorySemanticsMask::AcquireRelease |
|
||||
spv::MemorySemanticsMask::SequentiallyConsistent);
|
||||
if (semantics_value & memory_order_mask) {
|
||||
const analysis::Constant* constant =
|
||||
context()->get_constant_mgr()->GetConstant(
|
||||
semantics_type,
|
||||
{static_cast<uint32_t>(semantics_value) |
|
||||
uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR)});
|
||||
barrier->SetInOperand(2u, {context()
|
||||
->get_constant_mgr()
|
||||
->GetDefiningInstruction(constant)
|
||||
->result_id()});
|
||||
}
|
||||
}
|
||||
}
|
||||
barriers.clear();
|
||||
|
||||
@@ -59,7 +59,10 @@ void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst,
|
||||
*out << spvtools::utils::FloatProxy<spvtools::utils::Float8_E5M2>(
|
||||
uint8_t(word & 0xFF));
|
||||
break;
|
||||
// TODO Bfloat16
|
||||
case SPV_FP_ENCODING_BFLOAT16:
|
||||
*out << spvtools::utils::FloatProxy<spvtools::utils::BFloat16>(
|
||||
uint16_t(word & 0xFFFF));
|
||||
break;
|
||||
case SPV_FP_ENCODING_UNKNOWN:
|
||||
switch (operand.number_bit_width) {
|
||||
case 16:
|
||||
|
||||
2
3rdparty/spirv-tools/source/text_handler.cpp
vendored
2
3rdparty/spirv-tools/source/text_handler.cpp
vendored
@@ -336,7 +336,7 @@ spv_result_t AssemblyContext::recordTypeDefinition(
|
||||
return diagnostic() << "Invalid OpTypeFloat instruction";
|
||||
spv_fp_encoding_t enc = SPV_FP_ENCODING_UNKNOWN;
|
||||
if (pInst->words.size() >= 4) {
|
||||
const spvtools::OperandDesc* desc;
|
||||
const spvtools::OperandDesc* desc = nullptr;
|
||||
spv_result_t status = spvtools::LookupOperand(SPV_OPERAND_TYPE_FPENCODING,
|
||||
pInst->words[3], &desc);
|
||||
if (status == SPV_SUCCESS) {
|
||||
|
||||
90
3rdparty/spirv-tools/source/util/hex_float.h
vendored
90
3rdparty/spirv-tools/source/util/hex_float.h
vendored
@@ -103,6 +103,34 @@ class Float16 {
|
||||
uint16_t val;
|
||||
};
|
||||
|
||||
// A 16-bit bfloat16 value held as its raw bit pattern.
// Bit layout: bit 15 is the sign, bits 14-7 are the exponent and bits 6-0 are
// the mantissa.
class BFloat16 {
 public:
  // Default construction yields the all-zero bit pattern rather than an
  // indeterminate value.
  BFloat16() = default;
  BFloat16(uint16_t v) : val(v) {}
  // Copy operations are defaulted as a pair so that copy assignment does not
  // rely on the deprecated implicit definition.
  BFloat16(const BFloat16& other) = default;
  BFloat16& operator=(const BFloat16& other) = default;

  // NaN: exponent bits all set and a non-zero mantissa.
  static bool isNan(const BFloat16& val) {
    return ((val.val & kExponentMask) == kExponentMask) &&
           ((val.val & kMantissaMask) != 0);
  }
  // Infinity (either sign): exponent bits all set and a zero mantissa.
  static bool isInfinity(const BFloat16& val) {
    return ((val.val & kExponentMask) == kExponentMask) &&
           ((val.val & kMantissaMask) == 0);
  }

  // Returns the raw 16-bit pattern.
  uint16_t get_value() const { return val; }

  // a sign bit of 0, and an all 1 mantissa.
  static BFloat16 max() { return BFloat16(0x7F7F); }
  // a sign bit of 1, and an all 1 mantissa.
  static BFloat16 lowest() { return BFloat16(0xFF7F); }

 private:
  static constexpr uint16_t kExponentMask = 0x7F80;
  static constexpr uint16_t kMantissaMask = 0x007F;

  // 15: Sign
  // 14-7: Exponent
  // 6-0: Mantissa
  uint16_t val = 0;
};
|
||||
|
||||
// To specialize this type, you must override uint_type to define
|
||||
// an unsigned integer that can fit your floating point type.
|
||||
// You must also add a isNan function that returns true if
|
||||
@@ -212,6 +240,24 @@ struct FloatProxyTraits<Float16> {
|
||||
static uint32_t width() { return 16u; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FloatProxyTraits<BFloat16> {
|
||||
using uint_type = uint16_t;
|
||||
static bool isNan(BFloat16 f) { return BFloat16::isNan(f); }
|
||||
// Returns true if the given value is any kind of infinity.
|
||||
static bool isInfinity(BFloat16 f) { return BFloat16::isInfinity(f); }
|
||||
// Returns the maximum normal value.
|
||||
static BFloat16 max() { return BFloat16::max(); }
|
||||
// Returns the lowest normal value.
|
||||
static BFloat16 lowest() { return BFloat16::lowest(); }
|
||||
// Returns the value as the native floating point format.
|
||||
static BFloat16 getAsFloat(const uint_type& t) { return BFloat16(t); }
|
||||
// Returns the bits from the given floating pointer number.
|
||||
static uint_type getBitsFromFloat(const BFloat16& t) { return t.get_value(); }
|
||||
// Returns the bitwidth.
|
||||
static uint32_t width() { return 16u; }
|
||||
};
|
||||
|
||||
// Since copying a floating point number (especially if it is NaN)
|
||||
// does not guarantee that bits are preserved, this class lets us
|
||||
// store the type and use it as a float when necessary.
|
||||
@@ -403,6 +449,23 @@ struct HexFloatTraits<FloatProxy<Float16>> {
|
||||
static const uint_type NaN_pattern = 0x7c00;
|
||||
};
|
||||
|
||||
// Traits for BFloat16.
|
||||
// 1 sign bit, 8 exponent bits, 7 fractional bits.
|
||||
template <>
|
||||
struct HexFloatTraits<FloatProxy<BFloat16>> {
|
||||
using uint_type = uint16_t;
|
||||
using int_type = int16_t;
|
||||
using underlying_type = FloatProxy<BFloat16>;
|
||||
using underlying_typetraits = FloatProxyTraits<BFloat16>;
|
||||
using native_type = uint16_t;
|
||||
static const uint_type num_used_bits = 16;
|
||||
static const uint_type num_exponent_bits = 8;
|
||||
static const uint_type num_fraction_bits = 7;
|
||||
static const uint_type exponent_bias = 127;
|
||||
static const bool has_infinity = true;
|
||||
static const uint_type NaN_pattern = 0x7F80;
|
||||
};
|
||||
|
||||
enum class round_direction {
|
||||
kToZero,
|
||||
kToNearestEven,
|
||||
@@ -1038,6 +1101,26 @@ ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
|
||||
}
|
||||
return is;
|
||||
}
|
||||
|
||||
// Same flow as Float16
|
||||
template <>
|
||||
inline std::istream&
|
||||
ParseNormalFloat<FloatProxy<BFloat16>, HexFloatTraits<FloatProxy<BFloat16>>>(
|
||||
std::istream& is, bool negate_value,
|
||||
HexFloat<FloatProxy<BFloat16>, HexFloatTraits<FloatProxy<BFloat16>>>&
|
||||
value) {
|
||||
HexFloat<FloatProxy<float>> float_val(0.0f);
|
||||
ParseNormalFloat(is, negate_value, float_val);
|
||||
|
||||
float_val.castTo(value, round_direction::kToZero);
|
||||
|
||||
if (BFloat16::isInfinity(value.value().getAsFloat())) {
|
||||
value.set_value(value.isNegative() ? BFloat16::lowest() : BFloat16::max());
|
||||
is.setstate(std::ios_base::failbit);
|
||||
}
|
||||
return is;
|
||||
}
|
||||
|
||||
// Specialization of ParseNormalFloat for FloatProxy<Float8_E4M3> values.
|
||||
// This will parse the float as it were a 32-bit floating point number,
|
||||
// and then round it down to fit into a Float8_E4M3 value.
|
||||
@@ -1468,6 +1551,13 @@ inline std::ostream& operator<<<Float16>(std::ostream& os,
|
||||
return os;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::ostream& operator<< <BFloat16>(std::ostream& os,
|
||||
const FloatProxy<BFloat16>& value) {
|
||||
os << HexFloat<FloatProxy<BFloat16>>(value);
|
||||
return os;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::ostream& operator<< <Float8_E4M3>(
|
||||
std::ostream& os, const FloatProxy<Float8_E4M3>& value) {
|
||||
|
||||
@@ -185,7 +185,15 @@ EncodeNumberStatus ParseAndEncodeFloatingPointNumber(
|
||||
emit(static_cast<uint32_t>(hVal.value().getAsFloat().get_value()));
|
||||
return EncodeNumberStatus::kSuccess;
|
||||
} break;
|
||||
case SPV_FP_ENCODING_BFLOAT16: // FIXME this likely needs separate handling
|
||||
case SPV_FP_ENCODING_BFLOAT16: {
|
||||
HexFloat<FloatProxy<BFloat16>> hVal(0);
|
||||
if (!ParseNumber(text, &hVal)) {
|
||||
ErrorMsgStream(error_msg) << "Invalid bfloat16 literal: " << text;
|
||||
return EncodeNumberStatus::kInvalidText;
|
||||
}
|
||||
emit(static_cast<uint32_t>(hVal.value().getAsFloat().get_value()));
|
||||
return EncodeNumberStatus::kSuccess;
|
||||
} break;
|
||||
case SPV_FP_ENCODING_IEEE754_BINARY16: {
|
||||
HexFloat<FloatProxy<Float16>> hVal(0);
|
||||
if (!ParseNumber(text, &hVal)) {
|
||||
|
||||
80
3rdparty/spirv-tools/source/val/validate.cpp
vendored
80
3rdparty/spirv-tools/source/val/validate.cpp
vendored
@@ -64,9 +64,12 @@ void RegisterExtension(ValidationState_t& _,
|
||||
spv_result_t ProcessExtensions(void* user_data,
|
||||
const spv_parsed_instruction_t* inst) {
|
||||
const spv::Op opcode = static_cast<spv::Op>(inst->opcode);
|
||||
if (opcode == spv::Op::OpCapability) return SPV_SUCCESS;
|
||||
if (opcode == spv::Op::OpCapability ||
|
||||
opcode == spv::Op::OpConditionalCapabilityINTEL)
|
||||
return SPV_SUCCESS;
|
||||
|
||||
if (opcode == spv::Op::OpExtension) {
|
||||
if (opcode == spv::Op::OpExtension ||
|
||||
opcode == spv::Op::OpConditionalExtensionINTEL) {
|
||||
ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
|
||||
RegisterExtension(_, inst);
|
||||
return SPV_SUCCESS;
|
||||
@@ -115,10 +118,11 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
|
||||
_.ComputeFunctionToEntryPointMapping();
|
||||
_.ComputeRecursiveEntryPoints();
|
||||
|
||||
if (_.entry_points().empty() && !_.HasCapability(spv::Capability::Linkage)) {
|
||||
if (_.entry_points().empty() && !_.HasCapability(spv::Capability::Linkage) &&
|
||||
!_.HasCapability(spv::Capability::GraphARM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
|
||||
<< "No OpEntryPoint instruction was found. This is only allowed if "
|
||||
"the Linkage capability is being used.";
|
||||
"the Linkage or GraphARM capability is being used.";
|
||||
}
|
||||
|
||||
for (const auto& entry_point : _.entry_points()) {
|
||||
@@ -151,6 +155,16 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Reports an error when the GraphARM capability is declared but no graph
// entry point has been registered; otherwise succeeds.
spv_result_t ValidateGraphEntryPoints(ValidationState_t& _) {
  const bool missing_graph_entry_point =
      _.graph_entry_points().empty() &&
      _.HasCapability(spv::Capability::GraphARM);
  if (!missing_graph_entry_point) {
    return SPV_SUCCESS;
  }
  return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
         << "No OpGraphEntryPointARM instruction was found but the GraphARM "
            "capability is declared.";
}
|
||||
|
||||
spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
const spv_context_t& context, const uint32_t* words, const size_t num_words,
|
||||
spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
|
||||
@@ -217,43 +231,59 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
// able to, briefly, de-const the instruction.
|
||||
Instruction* inst = const_cast<Instruction*>(&instruction);
|
||||
|
||||
if (inst->opcode() == spv::Op::OpEntryPoint) {
|
||||
const auto entry_point = inst->GetOperandAs<uint32_t>(1);
|
||||
const auto execution_model = inst->GetOperandAs<spv::ExecutionModel>(0);
|
||||
const std::string desc_name = inst->GetOperandAs<std::string>(2);
|
||||
if ((inst->opcode() == spv::Op::OpEntryPoint) ||
|
||||
(inst->opcode() == spv::Op::OpConditionalEntryPointINTEL)) {
|
||||
const int i_model = inst->opcode() == spv::Op::OpEntryPoint ? 0 : 1;
|
||||
const int i_point = inst->opcode() == spv::Op::OpEntryPoint ? 1 : 2;
|
||||
const int i_name = inst->opcode() == spv::Op::OpEntryPoint ? 2 : 3;
|
||||
const int min_num_operands =
|
||||
inst->opcode() == spv::Op::OpEntryPoint ? 3 : 4;
|
||||
|
||||
const auto entry_point = inst->GetOperandAs<uint32_t>(i_point);
|
||||
const auto execution_model =
|
||||
inst->GetOperandAs<spv::ExecutionModel>(i_model);
|
||||
const std::string desc_name = inst->GetOperandAs<std::string>(i_name);
|
||||
|
||||
ValidationState_t::EntryPointDescription desc;
|
||||
desc.name = desc_name;
|
||||
|
||||
std::vector<uint32_t> interfaces;
|
||||
for (size_t j = 3; j < inst->operands().size(); ++j)
|
||||
for (size_t j = min_num_operands; j < inst->operands().size(); ++j)
|
||||
desc.interfaces.push_back(inst->word(inst->operand(j).offset));
|
||||
|
||||
vstate->RegisterEntryPoint(entry_point, execution_model,
|
||||
std::move(desc));
|
||||
|
||||
if (visited_entry_points.size() > 0) {
|
||||
for (const Instruction* check_inst : visited_entry_points) {
|
||||
const auto check_execution_model =
|
||||
check_inst->GetOperandAs<spv::ExecutionModel>(0);
|
||||
const std::string check_name =
|
||||
check_inst->GetOperandAs<std::string>(2);
|
||||
if (inst->opcode() == spv::Op::OpEntryPoint) {
|
||||
// conditional entry points are allowed to share the same name and
|
||||
// exec mode
|
||||
if (visited_entry_points.size() > 0) {
|
||||
for (const Instruction* check_inst : visited_entry_points) {
|
||||
const auto check_execution_model =
|
||||
check_inst->GetOperandAs<spv::ExecutionModel>(i_model);
|
||||
const std::string check_name =
|
||||
check_inst->GetOperandAs<std::string>(i_name);
|
||||
|
||||
if (desc_name == check_name &&
|
||||
execution_model == check_execution_model) {
|
||||
return vstate->diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "2 Entry points cannot share the same name and "
|
||||
"ExecutionMode.";
|
||||
if (desc_name == check_name &&
|
||||
execution_model == check_execution_model) {
|
||||
return vstate->diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "2 Entry points cannot share the same name and "
|
||||
"ExecutionMode.";
|
||||
}
|
||||
}
|
||||
}
|
||||
visited_entry_points.push_back(inst);
|
||||
}
|
||||
visited_entry_points.push_back(inst);
|
||||
|
||||
has_mask_task_nv |= (execution_model == spv::ExecutionModel::TaskNV ||
|
||||
execution_model == spv::ExecutionModel::MeshNV);
|
||||
has_mask_task_ext |= (execution_model == spv::ExecutionModel::TaskEXT ||
|
||||
execution_model == spv::ExecutionModel::MeshEXT);
|
||||
}
|
||||
if (inst->opcode() == spv::Op::OpGraphEntryPointARM) {
|
||||
const auto graph = inst->GetOperandAs<uint32_t>(1);
|
||||
vstate->RegisterGraphEntryPoint(graph);
|
||||
}
|
||||
if (inst->opcode() == spv::Op::OpFunctionCall) {
|
||||
if (!vstate->in_function_body()) {
|
||||
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction)
|
||||
@@ -299,6 +329,10 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
|
||||
<< "Missing OpFunctionEnd at end of module.";
|
||||
|
||||
if (vstate->graph_definition_region() != kGraphDefinitionOutside)
|
||||
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
|
||||
<< "Missing OpGraphEndARM at end of module.";
|
||||
|
||||
if (vstate->HasCapability(spv::Capability::BindlessTextureNV) &&
|
||||
!vstate->has_samplerimage_variable_address_mode_specified())
|
||||
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
|
||||
@@ -314,7 +348,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
if (auto error = ValidateForwardDecls(*vstate)) return error;
|
||||
|
||||
// Calculate reachability after all the blocks are parsed, but early enough that it
|
||||
// can be relied on in subsequent pases.
|
||||
// can be relied on in subsequent passes.
|
||||
ReachabilityPass(*vstate);
|
||||
|
||||
// ID usage needs be handled in its own iteration of the instructions,
|
||||
@@ -368,6 +402,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
|
||||
if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
|
||||
if (auto error = TensorPass(*vstate, &instruction)) return error;
|
||||
if (auto error = GraphPass(*vstate, &instruction)) return error;
|
||||
if (auto error = InvalidTypePass(*vstate, &instruction)) return error;
|
||||
}
|
||||
|
||||
@@ -377,6 +412,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
if (auto error = ValidateAdjacency(*vstate)) return error;
|
||||
|
||||
if (auto error = ValidateEntryPoints(*vstate)) return error;
|
||||
if (auto error = ValidateGraphEntryPoints(*vstate)) return error;
|
||||
// CFG checks are performed after the binary has been parsed
|
||||
// and the CFGPass has collected information about the control flow
|
||||
if (auto error = PerformCfgChecks(*vstate)) return error;
|
||||
|
||||
7
3rdparty/spirv-tools/source/val/validate.h
vendored
7
3rdparty/spirv-tools/source/val/validate.h
vendored
@@ -195,8 +195,8 @@ spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst);
|
||||
/// Validates correctness of debug instructions.
|
||||
spv_result_t DebugPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
// Validates that capability declarations use operands allowed in the current
|
||||
// context.
|
||||
/// Validates that capability declarations use operands allowed in the current
|
||||
/// context.
|
||||
spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
/// Validates correctness of primitive instructions.
|
||||
@@ -226,6 +226,9 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
|
||||
/// Validates correctness of tensor instructions.
|
||||
spv_result_t TensorPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
/// Validates correctness of graph instructions.
|
||||
spv_result_t GraphPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
/// Validates correctness of certain special type instructions.
|
||||
spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
|
||||
@@ -333,6 +333,14 @@ spv_result_t ValidateDecorate(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
|
||||
spv_result_t ValidateDecorateId(ValidationState_t& _, const Instruction* inst) {
|
||||
const auto target_id = inst->GetOperandAs<uint32_t>(0);
|
||||
const auto target = _.FindDef(target_id);
|
||||
if (target && spv::Op::OpDecorationGroup == target->opcode()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpMemberDecorate Target <id> " << _.getIdName(target_id)
|
||||
<< " must not be an OpDecorationGroup instruction.";
|
||||
}
|
||||
|
||||
const auto decoration = inst->GetOperandAs<spv::Decoration>(1);
|
||||
if (!DecorationTakesIdParameters(decoration)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
@@ -340,6 +348,20 @@ spv_result_t ValidateDecorateId(ValidationState_t& _, const Instruction* inst) {
|
||||
"OpDecorateId";
|
||||
}
|
||||
|
||||
for (uint32_t i = 2; i < inst->operands().size(); ++i) {
|
||||
const auto param_id = inst->GetOperandAs<uint32_t>(i);
|
||||
const auto param = _.FindDef(param_id);
|
||||
|
||||
// Both target and param are elements of ordered_instructions we can
|
||||
// determine their relative positions in the SPIR-V module by comparing
|
||||
// pointers.
|
||||
if (target <= param) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Parameter <ID> " << _.getIdName(param_id)
|
||||
<< " must appear earlier in the binary than the target";
|
||||
}
|
||||
}
|
||||
|
||||
// No member decorations take id parameters, so we don't bother checking if
|
||||
// we are using a member only decoration here.
|
||||
|
||||
@@ -388,8 +410,7 @@ spv_result_t ValidateDecorationGroup(ValidationState_t& _,
|
||||
if (use->opcode() != spv::Op::OpDecorate &&
|
||||
use->opcode() != spv::Op::OpGroupDecorate &&
|
||||
use->opcode() != spv::Op::OpGroupMemberDecorate &&
|
||||
use->opcode() != spv::Op::OpName &&
|
||||
use->opcode() != spv::Op::OpDecorateId && !use->IsNonSemantic()) {
|
||||
use->opcode() != spv::Op::OpName && !use->IsNonSemantic()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Result id of OpDecorationGroup can only "
|
||||
<< "be targeted by OpName, OpGroupDecorate, "
|
||||
|
||||
@@ -388,27 +388,6 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
|
||||
if (auto error = ValidateMemorySemantics(
|
||||
_, inst, unequal_semantics_index, memory_scope))
|
||||
return error;
|
||||
|
||||
// Volatile bits must match for equal and unequal semantics. Previous
|
||||
// checks guarantee they are 32-bit constants, but we need to recheck
|
||||
// whether they are evaluatable constants.
|
||||
bool is_int32 = false;
|
||||
bool is_equal_const = false;
|
||||
bool is_unequal_const = false;
|
||||
uint32_t equal_value = 0;
|
||||
uint32_t unequal_value = 0;
|
||||
std::tie(is_int32, is_equal_const, equal_value) = _.EvalInt32IfConst(
|
||||
inst->GetOperandAs<uint32_t>(equal_semantics_index));
|
||||
std::tie(is_int32, is_unequal_const, unequal_value) =
|
||||
_.EvalInt32IfConst(
|
||||
inst->GetOperandAs<uint32_t>(unequal_semantics_index));
|
||||
if (is_equal_const && is_unequal_const &&
|
||||
((equal_value & uint32_t(spv::MemorySemanticsMask::Volatile)) ^
|
||||
(unequal_value & uint32_t(spv::MemorySemanticsMask::Volatile)))) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Volatile mask setting must match for Equal and Unequal "
|
||||
"memory semantics";
|
||||
}
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpAtomicStore) {
|
||||
|
||||
@@ -45,10 +45,10 @@ spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst) {
|
||||
model != spv::ExecutionModel::MeshNV) {
|
||||
if (message) {
|
||||
*message =
|
||||
"OpControlBarrier requires one of the following "
|
||||
"Execution "
|
||||
"Models: TessellationControl, GLCompute, Kernel, "
|
||||
"MeshNV or TaskNV";
|
||||
"In SPIR-V 1.2 or earlier, OpControlBarrier requires "
|
||||
"one of the following "
|
||||
"Execution Models: TessellationControl, GLCompute, "
|
||||
"Kernel, MeshNV or TaskNV";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -39,9 +39,11 @@ spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
|
||||
if (_.GetBitWidth(base_type) != 32 &&
|
||||
!_.options()->allow_vulkan_32_bit_bitwise) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4781)
|
||||
<< _.VkErrorID(10824)
|
||||
<< "Expected 32-bit int type for Base operand: "
|
||||
<< spvOpcodeString(opcode);
|
||||
<< spvOpcodeString(opcode)
|
||||
<< _.MissingFeature("maintenance9 feature",
|
||||
"--allow-vulkan-32-bit-bitwise", false);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,21 +17,22 @@
|
||||
// Validates correctness of built-in variables.
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "source/opcode.h"
|
||||
#include "source/spirv_target_env.h"
|
||||
#include "source/util/bitutils.h"
|
||||
#include "source/val/instruction.h"
|
||||
#include "source/val/validate.h"
|
||||
#include "source/val/validation_state.h"
|
||||
#include "spirv/unified1/spirv.hpp11"
|
||||
|
||||
namespace spvtools {
|
||||
namespace val {
|
||||
@@ -373,6 +374,18 @@ class BuiltInsValidator {
|
||||
spv_result_t ValidateMeshShadingEXTBuiltinsAtDefinition(
|
||||
const Decoration& decoration, const Instruction& inst);
|
||||
|
||||
// Used as a common method for validating MeshEXT builtins
|
||||
spv_result_t ValidateMeshBuiltinInterfaceRules(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
spv::Op scalar_type, const Instruction& referenced_from_inst);
|
||||
spv_result_t ValidatePrimitiveShadingRateInterfaceRules(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
const Instruction& referenced_from_inst);
|
||||
// Builtins that need checking in case they are **not** used with MeshEXT
|
||||
spv_result_t ValidateNonMeshInterfaceRules(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
const Instruction& referenced_from_inst);
|
||||
|
||||
// The following section contains functions which are called when id defined
|
||||
// by |referenced_inst| is
|
||||
// 1. referenced by |referenced_from_inst|
|
||||
@@ -590,8 +603,9 @@ class BuiltInsValidator {
|
||||
spv_result_t ValidateBool(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
const std::function<spv_result_t(const std::string& message)>& diag);
|
||||
spv_result_t ValidateBlockBoolOrArrayedBool(
|
||||
spv_result_t ValidateBlockTypeOrArrayedType(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
bool& present_in_block, spv::Op expected_scalar_type,
|
||||
const std::function<spv_result_t(const std::string& message)>& diag);
|
||||
spv_result_t ValidateI(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
@@ -675,20 +689,50 @@ class BuiltInsValidator {
|
||||
// UniformConstant".
|
||||
std::string GetStorageClassDesc(const Instruction& inst) const;
|
||||
|
||||
uint64_t GetArrayLength(uint32_t interface_var_id);
|
||||
|
||||
// Updates inner working of the class. Is called sequentially for every
|
||||
// instruction.
|
||||
void Update(const Instruction& inst);
|
||||
|
||||
// Check if "inst" is an interface variable
|
||||
// or type of a interface varibale of any mesh entry point
|
||||
bool isMeshInterfaceVar(const Instruction& inst) {
|
||||
auto getUnderlyingTypeId = [&](const Instruction* ifxVar) {
|
||||
auto pointerTypeInst = _.FindDef(ifxVar->type_id());
|
||||
auto typeInst = _.FindDef(pointerTypeInst->GetOperandAs<uint32_t>(2));
|
||||
while (typeInst->opcode() == spv::Op::OpTypeArray) {
|
||||
typeInst = _.FindDef(typeInst->GetOperandAs<uint32_t>(1));
|
||||
bool IsBulitinInEntryPoint(const Instruction& inst, uint32_t entry_point) {
|
||||
auto get_underlying_type_id = [&](const Instruction* ifx_var) {
|
||||
auto pointer_type_inst = _.FindDef(ifx_var->type_id());
|
||||
auto type_inst = _.FindDef(pointer_type_inst->GetOperandAs<uint32_t>(2));
|
||||
while (type_inst->opcode() == spv::Op::OpTypeArray) {
|
||||
type_inst = _.FindDef(type_inst->GetOperandAs<uint32_t>(1));
|
||||
};
|
||||
return typeInst->id();
|
||||
return type_inst->id();
|
||||
};
|
||||
|
||||
for (const auto& desc : _.entry_point_descriptions(entry_point)) {
|
||||
for (auto interface : desc.interfaces) {
|
||||
if (inst.opcode() == spv::Op::OpTypeStruct) {
|
||||
auto varInst = _.FindDef(interface);
|
||||
if (inst.id() == get_underlying_type_id(varInst)) {
|
||||
return true;
|
||||
}
|
||||
} else if (inst.id() == interface) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if "inst" is an interface variable or type of an interface variable
|
||||
// of any mesh entry point. Populate entry_point_interface_id with all
|
||||
// entry points and interface variables that refer to the "inst"
|
||||
bool IsMeshInterfaceVar(
|
||||
const Instruction& inst,
|
||||
std::map<uint32_t, uint32_t>& entry_point_interface_id) {
|
||||
auto get_underlying_type_id = [&](const Instruction* ifx_var) {
|
||||
auto pointer_type_inst = _.FindDef(ifx_var->type_id());
|
||||
auto type_inst = _.FindDef(pointer_type_inst->GetOperandAs<uint32_t>(2));
|
||||
while (type_inst->opcode() == spv::Op::OpTypeArray) {
|
||||
type_inst = _.FindDef(type_inst->GetOperandAs<uint32_t>(1));
|
||||
};
|
||||
return type_inst->id();
|
||||
};
|
||||
|
||||
for (const uint32_t entry_point : _.entry_points()) {
|
||||
@@ -699,15 +743,19 @@ class BuiltInsValidator {
|
||||
for (auto interface : desc.interfaces) {
|
||||
if (inst.opcode() == spv::Op::OpTypeStruct) {
|
||||
auto varInst = _.FindDef(interface);
|
||||
if (inst.id() == getUnderlyingTypeId(varInst)) return true;
|
||||
if (inst.id() == get_underlying_type_id(varInst)) {
|
||||
entry_point_interface_id[entry_point] = interface;
|
||||
break;
|
||||
}
|
||||
} else if (inst.id() == interface) {
|
||||
return true;
|
||||
entry_point_interface_id[entry_point] = interface;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return !entry_point_interface_id.empty();
|
||||
}
|
||||
|
||||
ValidationState_t& _;
|
||||
@@ -730,6 +778,10 @@ class BuiltInsValidator {
|
||||
|
||||
// Execution models with which the current function can be called.
|
||||
std::set<spv::ExecutionModel> execution_models_;
|
||||
|
||||
// For Builtin that can only be declared once in an entry point, keep track if
|
||||
// the entry point has it already
|
||||
std::set<uint32_t> cull_primitive_entry_points_;
|
||||
};
|
||||
|
||||
void BuiltInsValidator::Update(const Instruction& inst) {
|
||||
@@ -807,6 +859,29 @@ std::string BuiltInsValidator::GetStorageClassDesc(
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Returns the element count of the array that the variable |interface_var_id|
// points to.
//
// Return values, as written below:
//   - the evaluated constant length when the pointee type is an OpTypeArray;
//   - uint64_t(-1), i.e. the maximum uint64_t value, when the id is not an
//     OpVariable or when the pointee type is not an array;
//   - 0 when the pointer type information cannot be retrieved or the array
//     length cannot be evaluated as a constant.
//
// NOTE(review): the result of FindDef() is dereferenced without a null check;
// presumably callers only pass ids that are known to be defined - confirm.
uint64_t BuiltInsValidator::GetArrayLength(uint32_t interface_var_id) {
|
||||
uint32_t underlying_type;
|
||||
spv::StorageClass storage_class;
|
||||
// -1 wraps to the maximum uint64_t value; it is what gets returned when the
// pointee turns out not to be an array.
uint64_t array_len = -1;
|
||||
const Instruction* inst = _.FindDef(interface_var_id);
|
||||
if (inst->opcode() != spv::Op::OpVariable) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!_.GetPointerTypeInfo(inst->type_id(), &underlying_type,
|
||||
&storage_class)) {
|
||||
return 0;
|
||||
}
|
||||
if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
|
||||
// Get the array length
|
||||
// Word 3 of the OpTypeArray instruction is read as the id of the length.
const auto length_id = _.FindDef(underlying_type)->word(3u);
|
||||
if (!_.EvalConstantValUint64(length_id, &array_len)) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return array_len;
|
||||
}
|
||||
|
||||
spv_result_t BuiltInsValidator::ValidateBool(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
const std::function<spv_result_t(const std::string& message)>& diag) {
|
||||
@@ -823,25 +898,50 @@ spv_result_t BuiltInsValidator::ValidateBool(
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t BuiltInsValidator::ValidateBlockBoolOrArrayedBool(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
spv_result_t BuiltInsValidator::ValidateBlockTypeOrArrayedType(
|
||||
const Decoration& decoration, const Instruction& inst, bool& isBlock,
|
||||
spv::Op expected_scalar_type,
|
||||
const std::function<spv_result_t(const std::string& message)>& diag) {
|
||||
uint32_t underlying_type = 0;
|
||||
int64_t array_len = -1;
|
||||
isBlock = true;
|
||||
if (spv_result_t error =
|
||||
GetUnderlyingType(_, decoration, inst, &underlying_type)) {
|
||||
return error;
|
||||
}
|
||||
// Strip the array, if present.
|
||||
if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
|
||||
// Get the array length
|
||||
const auto length_id = _.FindDef(underlying_type)->word(3u);
|
||||
if (!_.EvalConstantValInt64(length_id, &array_len)) {
|
||||
return diag(GetDefinitionDesc(decoration, inst) +
|
||||
" Failed to find the array length.");
|
||||
}
|
||||
underlying_type = _.FindDef(underlying_type)->word(2u);
|
||||
isBlock = false;
|
||||
} else if (!_.HasDecoration(inst.id(), spv::Decoration::Block)) {
|
||||
// If not in array, and bool is in a struct, must be in a Block struct
|
||||
return diag(GetDefinitionDesc(decoration, inst) +
|
||||
" Scalar boolean must be in a Block.");
|
||||
}
|
||||
|
||||
if (!_.IsBoolScalarType(underlying_type)) {
|
||||
return diag(GetDefinitionDesc(decoration, inst) + " is not a bool scalar.");
|
||||
switch (expected_scalar_type) {
|
||||
case spv::Op::OpTypeBool:
|
||||
if (!_.IsBoolScalarType(underlying_type)) {
|
||||
return diag(GetDefinitionDesc(decoration, inst) +
|
||||
" is not a bool scalar.");
|
||||
}
|
||||
break;
|
||||
case spv::Op::OpTypeInt:
|
||||
if (!_.IsIntScalarType(underlying_type)) {
|
||||
return diag(GetDefinitionDesc(decoration, inst) +
|
||||
" is not an integer scalar.");
|
||||
}
|
||||
break;
|
||||
default:
|
||||
assert(0 && "Unhandled scalar type");
|
||||
return diag(GetDefinitionDesc(decoration, inst) +
|
||||
" is not a recognized scalar type.");
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
@@ -2188,49 +2288,6 @@ spv_result_t BuiltInsValidator::ValidatePositionAtReference(
|
||||
|
||||
spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtDefinition(
|
||||
const Decoration& decoration, const Instruction& inst) {
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
// PrimitiveId can be a per-primitive variable for mesh shader stage.
|
||||
// In such cases variable will have an array of 32-bit integers.
|
||||
if (decoration.struct_member_index() != Decoration::kInvalidMember) {
|
||||
// This must be a 32-bit int scalar.
|
||||
if (spv_result_t error = ValidateI32(
|
||||
decoration, inst,
|
||||
[this, &inst](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(4337)
|
||||
<< "According to the Vulkan spec BuiltIn PrimitiveId "
|
||||
"variable needs to be a 32-bit int scalar. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
} else {
|
||||
if (spv_result_t error = ValidateOptionalArrayedI32(
|
||||
decoration, inst,
|
||||
[this, &inst](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(4337)
|
||||
<< "According to the Vulkan spec BuiltIn PrimitiveId "
|
||||
"variable needs to be a 32-bit int scalar. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
|
||||
if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
|
||||
if (isMeshInterfaceVar(inst) &&
|
||||
!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(7040)
|
||||
<< "According to the Vulkan spec the variable decorated with "
|
||||
"Builtin PrimitiveId within the MeshEXT Execution Model must "
|
||||
"also be decorated with the PerPrimitiveEXT decoration. ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Seed at reference checks with this built-in.
|
||||
return ValidatePrimitiveIdAtReference(decoration, inst, inst, inst);
|
||||
}
|
||||
|
||||
@@ -2297,6 +2354,27 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtReference(
|
||||
referenced_from_inst, std::placeholders::_1));
|
||||
}
|
||||
|
||||
if (!_.HasCapability(spv::Capability::MeshShadingEXT) &&
|
||||
!_.HasCapability(spv::Capability::MeshShadingNV) &&
|
||||
!_.HasCapability(spv::Capability::Geometry) &&
|
||||
!_.HasCapability(spv::Capability::Tessellation)) {
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
|
||||
&BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, 4333,
|
||||
"Vulkan spec doesn't allow BuiltIn PrimitiveId to be used for "
|
||||
"variables in the Fragment execution model unless it declares "
|
||||
"Geometry, Tessellation, or MeshShader capabilities.",
|
||||
spv::ExecutionModel::Fragment, decoration, built_in_inst,
|
||||
referenced_from_inst, std::placeholders::_1));
|
||||
}
|
||||
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
|
||||
&BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this, decoration,
|
||||
built_in_inst, spv::Op::OpTypeInt, std::placeholders::_1));
|
||||
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
|
||||
std::bind(&BuiltInsValidator::ValidateNonMeshInterfaceRules, this,
|
||||
decoration, built_in_inst, std::placeholders::_1));
|
||||
|
||||
for (const spv::ExecutionModel execution_model : execution_models_) {
|
||||
switch (execution_model) {
|
||||
case spv::ExecutionModel::Fragment:
|
||||
@@ -2593,6 +2671,13 @@ spv_result_t BuiltInsValidator::ValidateTessLevelOuterAtDefinition(
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
|
||||
if (!_.HasDecoration(inst.id(), spv::Decoration::Patch)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(10880)
|
||||
<< "BuiltIn TessLevelOuter variable needs to also have a Patch "
|
||||
"decoration.";
|
||||
}
|
||||
}
|
||||
|
||||
// Seed at reference checks with this built-in.
|
||||
@@ -2607,13 +2692,20 @@ spv_result_t BuiltInsValidator::ValidateTessLevelInnerAtDefinition(
|
||||
[this, &inst](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(4397)
|
||||
<< "According to the Vulkan spec BuiltIn TessLevelOuter "
|
||||
<< "According to the Vulkan spec BuiltIn TessLevelInner "
|
||||
"variable needs to be a 2-component 32-bit float "
|
||||
"array. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
|
||||
if (!_.HasDecoration(inst.id(), spv::Decoration::Patch)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(10880)
|
||||
<< "BuiltIn TessLevelInner variable needs to also have a Patch "
|
||||
"decoration.";
|
||||
}
|
||||
}
|
||||
|
||||
// Seed at reference checks with this built-in.
|
||||
@@ -2796,67 +2888,180 @@ spv_result_t BuiltInsValidator::ValidateVertexIndexAtReference(
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition(
|
||||
const Decoration& decoration, const Instruction& inst) {
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
// This can be a per-primitive variable for mesh shader stage.
|
||||
// In such cases variable will have an array of 32-bit integers.
|
||||
if (decoration.struct_member_index() != Decoration::kInvalidMember) {
|
||||
// This must be a 32-bit int scalar.
|
||||
// Groups the Vulkan VUID numbers reported by
// ValidateMeshBuiltinInterfaceRules for a single built-in. Each field is
// passed to VkErrorID() for one specific failure.
typedef struct {
|
||||
// Reported when ValidateBlockTypeOrArrayedType rejects the variable's type.
uint32_t array_type;
|
||||
// Reported on an array-size mismatch when the variable is not a Block.
uint32_t array_size;
|
||||
// Reported on an array-size mismatch when the variable is a Block.
uint32_t block_array_size;
|
||||
// Reported when the PerPrimitiveEXT decoration is missing.
uint32_t perprim_deco;
|
||||
} MeshBuiltinVUIDs;
|
||||
|
||||
// Common interface validation for built-ins that may be declared as
// per-primitive outputs of a MeshEXT entry point.
//
// |decoration|            the BuiltIn decoration being validated.
// |inst|                  the instruction carrying that decoration.
// |scalar_type|           the scalar opcode (e.g. OpTypeInt, OpTypeBool) that
//                         the variable, or the elements of its array, must
//                         have.
// |referenced_from_inst|  the instruction through which |inst| is referenced.
//
// While no function is being processed (function_id_ is zero) the check only
// re-registers itself on |referenced_from_inst| so it runs again for
// dependent ids. Inside a function that can be called with the MeshEXT
// execution model it verifies, in order:
//   1. the type is |scalar_type| or an array of it;
//   2. |inst| is decorated with PerPrimitiveEXT;
//   3. for every mesh entry point that lists the variable in its interface,
//      the array length equals that entry point's OutputPrimitivesEXT value.
//
// Returns the first diagnostic produced, or SPV_SUCCESS.
//
// NOTE(review): the VUID table is looked up with at(), which throws
// std::out_of_range for a built-in that is not listed; presumably callers only
// bind this check for the five built-ins in the table - confirm.
spv_result_t BuiltInsValidator::ValidateMeshBuiltinInterfaceRules(
    const Decoration& decoration, const Instruction& inst, spv::Op scalar_type,
    const Instruction& referenced_from_inst) {
  if (function_id_) {
    if (execution_models_.count(spv::ExecutionModel::MeshEXT)) {
      bool is_block = false;
      const spv::BuiltIn builtin = decoration.builtin();

      // Per built-in VUIDs: {array_type, array_size, block_array_size,
      // perprim_deco}.
      static const std::unordered_map<spv::BuiltIn, MeshBuiltinVUIDs>
          mesh_vuid_map = {{
              {spv::BuiltIn::CullPrimitiveEXT, {7036, 10589, 10590, 7038}},
              {spv::BuiltIn::PrimitiveId, {10595, 10596, 10597, 7040}},
              {spv::BuiltIn::Layer, {10592, 10593, 10594, 7039}},
              {spv::BuiltIn::ViewportIndex, {10601, 10602, 10603, 7060}},
              {spv::BuiltIn::PrimitiveShadingRateKHR,
               {10598, 10599, 10600, 7059}},
          }};
      const MeshBuiltinVUIDs& vuids = mesh_vuid_map.at(builtin);

      // 1. Type check; also reports through |is_block| whether the variable
      //    is a Block rather than an array.
      if (spv_result_t error = ValidateBlockTypeOrArrayedType(
              decoration, inst, is_block, scalar_type,
              [this, &inst, &builtin, &scalar_type,
               &vuids](const std::string& message) -> spv_result_t {
                return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                       << _.VkErrorID(vuids.array_type)
                       << "According to the Vulkan spec BuiltIn "
                       << _.grammar().lookupOperandName(
                              SPV_OPERAND_TYPE_BUILT_IN, (uint32_t)builtin)
                       << " variable needs to be either a "
                       << spvOpcodeString(scalar_type)
                       << " or an "
                          "array of "
                       << spvOpcodeString(scalar_type) << ". " << message;
              })) {
        return error;
      }

      // 2. The PerPrimitiveEXT decoration is mandatory under MeshEXT.
      if (!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
        return _.diag(SPV_ERROR_INVALID_DATA, &inst)
               << _.VkErrorID(vuids.perprim_deco)
               << "According to the Vulkan spec the variable decorated with "
                  "Builtin "
               << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
                                                (uint32_t)builtin)
               << " within the MeshEXT Execution Model must also be "
               << "decorated with the PerPrimitiveEXT decoration. ";
      }

      // 3. These built-ins may be arrays with MeshEXT. When they are, the
      //    array size has to line up with OutputPrimitivesEXT of every entry
      //    point whose interface refers to the variable.
      std::map<uint32_t, uint32_t> entry_interface_id_map;
      bool found = IsMeshInterfaceVar(inst, entry_interface_id_map);
      if (found) {
        for (const auto& id : entry_interface_id_map) {
          uint32_t entry_point_id = id.first;
          uint32_t interface_var_id = id.second;

          const uint64_t interface_size = GetArrayLength(interface_var_id);
          const uint32_t output_prim_size =
              _.GetOutputPrimitivesEXT(entry_point_id);
          if (interface_size != output_prim_size) {
            return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                   << _.VkErrorID(is_block ? vuids.block_array_size
                                           : vuids.array_size)
                   << " The size of the array decorated with "
                   << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
                                                    (uint32_t)builtin)
                   << " (" << interface_size
                   << ") must match the value specified by OutputPrimitivesEXT "
                      "("
                   << output_prim_size << "). ";
          }
        }
      }
    }
  } else {
    // Propagate this rule to all dependent ids in the global scope.
    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
        std::bind(&BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this,
                  decoration, inst, scalar_type, std::placeholders::_1));
  }
  return SPV_SUCCESS;
}
|
||||
|
||||
spv_result_t BuiltInsValidator::ValidatePrimitiveShadingRateInterfaceRules(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
const Instruction& referenced_from_inst) {
|
||||
if (function_id_) {
|
||||
if (!execution_models_.count(spv::ExecutionModel::MeshEXT)) {
|
||||
if (spv_result_t error = ValidateI32(
|
||||
decoration, inst,
|
||||
[this, &decoration,
|
||||
&inst](const std::string& message) -> spv_result_t {
|
||||
uint32_t vuid =
|
||||
(decoration.builtin() == spv::BuiltIn::Layer) ? 4276 : 4408;
|
||||
[this, &inst,
|
||||
&decoration](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(vuid)
|
||||
<< _.VkErrorID(4486)
|
||||
<< "According to the Vulkan spec BuiltIn "
|
||||
<< _.grammar().lookupOperandName(
|
||||
SPV_OPERAND_TYPE_BUILT_IN,
|
||||
(uint32_t)decoration.builtin())
|
||||
<< "variable needs to be a 32-bit int scalar. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
} else {
|
||||
if (spv_result_t error = ValidateOptionalArrayedI32(
|
||||
decoration, inst,
|
||||
[this, &decoration,
|
||||
&inst](const std::string& message) -> spv_result_t {
|
||||
uint32_t vuid =
|
||||
(decoration.builtin() == spv::BuiltIn::Layer) ? 4276 : 4408;
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(vuid)
|
||||
<< "According to the Vulkan spec BuiltIn "
|
||||
<< _.grammar().lookupOperandName(
|
||||
SPV_OPERAND_TYPE_BUILT_IN,
|
||||
(uint32_t)decoration.builtin())
|
||||
<< "variable needs to be a 32-bit int scalar. "
|
||||
<< " variable needs to be a 32-bit int scalar. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
|
||||
if (isMeshInterfaceVar(inst) &&
|
||||
_.HasCapability(spv::Capability::MeshShadingEXT) &&
|
||||
!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
|
||||
const spv::BuiltIn label = spv::BuiltIn(decoration.params()[0]);
|
||||
uint32_t vkerrid = (label == spv::BuiltIn::Layer) ? 7039 : 7060;
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(vkerrid)
|
||||
<< "According to the Vulkan spec the variable decorated with "
|
||||
"Builtin "
|
||||
<< _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
|
||||
decoration.params()[0])
|
||||
<< " within the MeshEXT Execution Model must also be decorated "
|
||||
"with the PerPrimitiveEXT decoration. ";
|
||||
}
|
||||
} else {
|
||||
// Propagate this rule to all dependant ids in the global scope.
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
|
||||
&BuiltInsValidator::ValidatePrimitiveShadingRateInterfaceRules, this,
|
||||
decoration, inst, std::placeholders::_1));
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Seed at reference checks with this built-in.
|
||||
// For Layer, ViewportIndex, and PrimitiveId
|
||||
spv_result_t BuiltInsValidator::ValidateNonMeshInterfaceRules(
|
||||
const Decoration& decoration, const Instruction& inst,
|
||||
const Instruction& referenced_from_inst) {
|
||||
if (function_id_) {
|
||||
// This can be a per-primitive variable for NV mesh shader stage.
|
||||
// In such cases variable will have an array of 32-bit integers.
|
||||
if (!execution_models_.count(spv::ExecutionModel::MeshEXT)) {
|
||||
const spv::BuiltIn builtin = decoration.builtin();
|
||||
const uint32_t vuid = (builtin == spv::BuiltIn::Layer) ? 4276
|
||||
: (builtin == spv::BuiltIn::ViewportIndex) ? 4408
|
||||
: 4337;
|
||||
if (decoration.struct_member_index() != Decoration::kInvalidMember) {
|
||||
if (spv_result_t error = ValidateI32(
|
||||
decoration, inst,
|
||||
[this, &vuid, builtin,
|
||||
&inst](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(vuid)
|
||||
<< "According to the Vulkan spec BuiltIn "
|
||||
<< _.grammar().lookupOperandName(
|
||||
SPV_OPERAND_TYPE_BUILT_IN, (uint32_t)builtin)
|
||||
<< "variable needs to be a 32-bit int scalar. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
} else if (spv_result_t error = ValidateOptionalArrayedI32(
|
||||
decoration, inst,
|
||||
[this, &vuid, builtin,
|
||||
&inst](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(vuid)
|
||||
<< "According to the Vulkan spec BuiltIn "
|
||||
<< _.grammar().lookupOperandName(
|
||||
SPV_OPERAND_TYPE_BUILT_IN,
|
||||
(uint32_t)builtin)
|
||||
<< "variable needs to be a 32-bit int scalar. "
|
||||
<< message;
|
||||
})) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Propagate this rule to all dependant ids in the global scope.
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
|
||||
std::bind(&BuiltInsValidator::ValidateNonMeshInterfaceRules, this,
|
||||
decoration, inst, std::placeholders::_1));
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Definition-time entry point for Layer / ViewportIndex. It forwards to the
// at-reference validator, passing the defining instruction for all three
// instruction arguments.
spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition(
    const Decoration& decoration, const Instruction& inst) {
  const Instruction& definition = inst;
  return ValidateLayerOrViewportIndexAtReference(decoration, definition,
                                                 definition, definition);
}
|
||||
|
||||
@@ -2914,6 +3119,14 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtReference(
|
||||
referenced_from_inst, std::placeholders::_1));
|
||||
}
|
||||
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
|
||||
&BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this, decoration,
|
||||
built_in_inst, spv::Op::OpTypeInt, std::placeholders::_1));
|
||||
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(
|
||||
std::bind(&BuiltInsValidator::ValidateNonMeshInterfaceRules, this,
|
||||
decoration, built_in_inst, std::placeholders::_1));
|
||||
|
||||
for (const spv::ExecutionModel execution_model : execution_models_) {
|
||||
switch (execution_model) {
|
||||
case spv::ExecutionModel::Geometry:
|
||||
@@ -3338,12 +3551,47 @@ spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition(
|
||||
bool static_x = _.EvalConstantValUint64(inst.word(3), &x_size);
|
||||
bool static_y = _.EvalConstantValUint64(inst.word(4), &y_size);
|
||||
bool static_z = _.EvalConstantValUint64(inst.word(5), &z_size);
|
||||
if (static_x && static_y && static_z &&
|
||||
((x_size * y_size * z_size) == 0)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< "WorkgroupSize decorations must not have a static "
|
||||
"product of zero (X = "
|
||||
<< x_size << ", Y = " << y_size << ", Z = " << z_size << ").";
|
||||
if (static_x && static_y && static_z) {
|
||||
const uint64_t product_size = x_size * y_size * z_size;
|
||||
if (product_size == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< "WorkgroupSize decorations must not have a static "
|
||||
"product of zero (X = "
|
||||
<< x_size << ", Y = " << y_size << ", Z = " << z_size << ").";
|
||||
}
|
||||
|
||||
// If there is a known static workgroup size, all entrypoints with
|
||||
// explicit derivative execution modes can be validated. These are only
|
||||
// found in execution models that support explicit workgroup sizes
|
||||
for (const uint32_t entry_point : _.entry_points()) {
|
||||
const auto* modes = _.GetExecutionModes(entry_point);
|
||||
if (!modes) continue;
|
||||
if (modes->count(spv::ExecutionMode::DerivativeGroupQuadsKHR)) {
|
||||
if (x_size % 2 != 0 || y_size % 2 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(10151)
|
||||
<< "WorkgroupSize decorations has a static dimensions of "
|
||||
"(X = "
|
||||
<< x_size << ", Y = " << y_size << ") but Entry Point id "
|
||||
<< entry_point
|
||||
<< " has an DerivativeGroupQuadsKHR execution mode, so "
|
||||
"both dimensions must be a multiple of 2";
|
||||
}
|
||||
}
|
||||
if (modes->count(spv::ExecutionMode::DerivativeGroupLinearKHR)) {
|
||||
if (product_size % 4 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(10152)
|
||||
<< "WorkgroupSize decorations has a static dimensions of "
|
||||
"(X = "
|
||||
<< x_size << ", Y = " << y_size << ", Z = " << z_size
|
||||
<< ") but Entry Point id " << entry_point
|
||||
<< " has an DerivativeGroupLinearKHR execution mode, so "
|
||||
"the product ("
|
||||
<< product_size << ") must be a multiple of 4";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3986,34 +4234,6 @@ spv_result_t BuiltInsValidator::ValidateNVSMOrARMCoreBuiltinsAtReference(
|
||||
|
||||
// Definition-time validation of the PrimitiveShadingRateKHR built-in.
//
// For Vulkan target environments the variable is checked with ValidateI32
// (VUID 4486) and, for a mesh interface variable under the MeshShadingEXT
// capability, must also carry the PerPrimitiveEXT decoration (VUID 7059).
// In every environment the at-reference checks are then seeded with |inst|.
spv_result_t BuiltInsValidator::ValidatePrimitiveShadingRateAtDefinition(
    const Decoration& decoration, const Instruction& inst) {
  const bool targets_vulkan = spvIsVulkanEnv(_.context()->target_env);

  if (targets_vulkan) {
    // Diagnostic emitted when the 32-bit int scalar check fails.
    const auto not_i32 = [this, &inst, &decoration](
                             const std::string& message) -> spv_result_t {
      return _.diag(SPV_ERROR_INVALID_DATA, &inst)
             << _.VkErrorID(4486) << "According to the Vulkan spec BuiltIn "
             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
                                              (uint32_t)decoration.builtin())
             << " variable needs to be a 32-bit int scalar. " << message;
    };

    const spv_result_t type_result = ValidateI32(decoration, inst, not_i32);
    if (type_result != SPV_SUCCESS) {
      return type_result;
    }

    // Evaluated in the same order as a single && chain would be.
    if (isMeshInterfaceVar(inst)) {
      if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
        if (!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
          return _.diag(SPV_ERROR_INVALID_DATA, &inst)
                 << _.VkErrorID(7059)
                 << "The variable decorated with PrimitiveShadingRateKHR "
                    "within the MeshEXT Execution Model must also be "
                    "decorated with the PerPrimitiveEXT decoration";
        }
      }
    }
  }

  // Seed at reference checks with this built-in.
  return ValidatePrimitiveShadingRateAtReference(decoration, inst, inst, inst);
}
|
||||
|
||||
@@ -4035,6 +4255,14 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveShadingRateAtReference(
|
||||
<< " " << GetStorageClassDesc(referenced_from_inst);
|
||||
}
|
||||
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
|
||||
&BuiltInsValidator::ValidateMeshBuiltinInterfaceRules, this, decoration,
|
||||
built_in_inst, spv::Op::OpTypeInt, std::placeholders::_1));
|
||||
|
||||
id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
|
||||
&BuiltInsValidator::ValidatePrimitiveShadingRateInterfaceRules, this,
|
||||
decoration, built_in_inst, std::placeholders::_1));
|
||||
|
||||
for (const spv::ExecutionModel execution_model : execution_models_) {
|
||||
switch (execution_model) {
|
||||
case spv::ExecutionModel::Vertex:
|
||||
@@ -4365,48 +4593,61 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
|
||||
return error;
|
||||
}
|
||||
break;
|
||||
case spv::BuiltIn::CullPrimitiveEXT:
|
||||
if (spv_result_t error = ValidateBlockBoolOrArrayedBool(
|
||||
decoration, inst,
|
||||
[this, &inst, &decoration,
|
||||
&vuid](const std::string& message) -> spv_result_t {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(vuid) << "According to the "
|
||||
<< spvLogStringForEnv(_.context()->target_env)
|
||||
<< " spec BuiltIn "
|
||||
<< _.grammar().lookupOperandName(
|
||||
SPV_OPERAND_TYPE_BUILT_IN,
|
||||
(uint32_t)decoration.builtin())
|
||||
<< " variable needs to be a either a boolean or an "
|
||||
"array of booleans."
|
||||
<< message;
|
||||
})) {
|
||||
case spv::BuiltIn::CullPrimitiveEXT: {
|
||||
// We know this only allowed for Mesh Execution Model
|
||||
if (spv_result_t error = ValidateMeshBuiltinInterfaceRules(
|
||||
decoration, inst, spv::Op::OpTypeBool, inst)) {
|
||||
return error;
|
||||
}
|
||||
if (!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(7038)
|
||||
<< "The variable decorated with CullPrimitiveEXT within the "
|
||||
"MeshEXT Execution Model must also be decorated with the "
|
||||
"PerPrimitiveEXT decoration ";
|
||||
|
||||
for (const uint32_t entry_point : _.entry_points()) {
|
||||
auto* models = _.GetExecutionModels(entry_point);
|
||||
if (models->find(spv::ExecutionModel::MeshEXT) == models->end() &&
|
||||
models->find(spv::ExecutionModel::MeshNV) == models->end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsBulitinInEntryPoint(inst, entry_point)) {
|
||||
if (cull_primitive_entry_points_.find(entry_point) !=
|
||||
cull_primitive_entry_points_.end()) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(10591)
|
||||
<< "There must be only one declaration of the "
|
||||
"CullPrimitiveEXT associated in entry point's "
|
||||
"interface. "
|
||||
<< GetIdDesc(*_.FindDef(entry_point));
|
||||
} else {
|
||||
cull_primitive_entry_points_.insert(entry_point);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
assert(0 && "Unexpected mesh EXT builtin");
|
||||
}
|
||||
for (const uint32_t entry_point : _.entry_points()) {
|
||||
// execution modes and builtin are both global, so only check these
|
||||
// builtin definitions if we know the entrypoint is Mesh
|
||||
auto* models = _.GetExecutionModels(entry_point);
|
||||
if (models->find(spv::ExecutionModel::MeshEXT) == models->end() &&
|
||||
models->find(spv::ExecutionModel::MeshNV) == models->end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto* modes = _.GetExecutionModes(entry_point);
|
||||
uint64_t maxOutputPrimitives = _.GetOutputPrimitivesEXT(entry_point);
|
||||
uint64_t max_output_primitives = _.GetOutputPrimitivesEXT(entry_point);
|
||||
uint32_t underlying_type = 0;
|
||||
if (spv_result_t error =
|
||||
GetUnderlyingType(_, decoration, inst, &underlying_type)) {
|
||||
return error;
|
||||
}
|
||||
|
||||
uint64_t primitiveArrayDim = 0;
|
||||
uint64_t primitive_array_dim = 0;
|
||||
if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) {
|
||||
underlying_type = _.FindDef(underlying_type)->word(3u);
|
||||
if (!_.EvalConstantValUint64(underlying_type, &primitiveArrayDim)) {
|
||||
if (!_.EvalConstantValUint64(underlying_type, &primitive_array_dim)) {
|
||||
assert(0 && "Array type definition is corrupt");
|
||||
}
|
||||
}
|
||||
@@ -4419,7 +4660,8 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
|
||||
"with "
|
||||
"the OutputPoints Execution Mode. ";
|
||||
}
|
||||
if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
|
||||
if (primitive_array_dim &&
|
||||
primitive_array_dim != max_output_primitives) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(7046)
|
||||
<< "The size of the array decorated with "
|
||||
@@ -4435,7 +4677,8 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
|
||||
"with "
|
||||
"the OutputLinesEXT Execution Mode. ";
|
||||
}
|
||||
if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
|
||||
if (primitive_array_dim &&
|
||||
primitive_array_dim != max_output_primitives) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(7052)
|
||||
<< "The size of the array decorated with "
|
||||
@@ -4451,7 +4694,8 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
|
||||
"with "
|
||||
"the OutputTrianglesEXT Execution Mode. ";
|
||||
}
|
||||
if (primitiveArrayDim && primitiveArrayDim != maxOutputPrimitives) {
|
||||
if (primitive_array_dim &&
|
||||
primitive_array_dim != max_output_primitives) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
|
||||
<< _.VkErrorID(7058)
|
||||
<< "The size of the array decorated with "
|
||||
@@ -4692,6 +4936,7 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinitionVulkan(
|
||||
case spv::BuiltIn::CullMaskKHR: {
|
||||
return ValidateRayTracingBuiltinsAtDefinition(decoration, inst);
|
||||
}
|
||||
// These are only for Mesh, not Task execution model
|
||||
case spv::BuiltIn::CullPrimitiveEXT:
|
||||
case spv::BuiltIn::PrimitivePointIndicesEXT:
|
||||
case spv::BuiltIn::PrimitiveLineIndicesEXT:
|
||||
|
||||
@@ -345,11 +345,18 @@ bool IsEnabledByCapabilityOpenCL_2_0(ValidationState_t& _,
|
||||
// Validates that capability declarations use operands allowed in the current
|
||||
// context.
|
||||
spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst) {
|
||||
if (inst->opcode() != spv::Op::OpCapability) return SPV_SUCCESS;
|
||||
if (inst->opcode() != spv::Op::OpCapability &&
|
||||
inst->opcode() != spv::Op::OpConditionalCapabilityINTEL)
|
||||
return SPV_SUCCESS;
|
||||
|
||||
assert(inst->operands().size() == 1);
|
||||
assert(!((inst->opcode() == spv::Op::OpCapability) ^
|
||||
(inst->operands().size() == 1)));
|
||||
assert(!((inst->opcode() == spv::Op::OpConditionalCapabilityINTEL) ^
|
||||
(inst->operands().size() == 2)));
|
||||
|
||||
const spv_parsed_operand_t& operand = inst->operand(0);
|
||||
const uint32_t i_cap =
|
||||
inst->opcode() == spv::Op::OpConditionalCapabilityINTEL ? 1 : 0;
|
||||
const spv_parsed_operand_t& operand = inst->operand(i_cap);
|
||||
|
||||
assert(operand.num_words == 1);
|
||||
assert(operand.offset < inst->words().size());
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
// Validates correctness of composite SPIR-V instructions.
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "source/opcode.h"
|
||||
#include "source/spirv_target_env.h"
|
||||
#include "source/val/instruction.h"
|
||||
@@ -618,8 +620,464 @@ spv_result_t ValidateCopyLogical(ValidationState_t& _,
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
// Validates an instruction that builds an OpTypeCooperativeMatrixKHR result
// from an OpTypeArray source (operand index 2).
//
// Checks performed, in order:
//   * the result type is OpTypeCooperativeMatrixKHR and the source type is
//     OpTypeArray;
//   * the result type's scope is Subgroup (skipped when the scope id is a
//     specialization constant);
//   * the source element type equals the result component type, or is an
//     unsigned 32-bit integer;
//   * per matrix use (MatrixAKHR / MatrixBKHR / MatrixAccumulatorKHR), the
//     component type, the row/column count and the source array length.
//
// Dimension checks are skipped when the relevant id is defined by a
// specialization-constant opcode.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INVALID_DATA with a diagnostic.
spv_result_t ValidateCompositeConstructCoopMatQCOM(ValidationState_t& _,
                                                   const Instruction* inst) {
  // True when |id| is defined by a specialization-constant opcode.
  const auto is_spec_constant = [&_](uint32_t id) -> bool {
    return spvOpcodeIsSpecConstant(_.FindDef(id)->opcode());
  };
  // True when |type_id| is an unsigned 32-bit integer scalar type.
  const auto is_uint32_scalar = [&_](uint32_t type_id) -> bool {
    return _.ContainsSizedIntOrFloatType(type_id, spv::Op::OpTypeInt, 32) &&
           _.IsUnsignedIntScalarType(type_id);
  };
  // True when |count| elements of |elt_type| are one of the accepted
  // combinations: 32 x 8-bit int, 16 x 16-bit float, or 8 x 32-bit float.
  const auto spans_256_bits = [&_](uint32_t elt_type, unsigned count) -> bool {
    return (_.ContainsSizedIntOrFloatType(elt_type, spv::Op::OpTypeInt, 8) &&
            count == 32) ||
           (_.ContainsSizedIntOrFloatType(elt_type, spv::Op::OpTypeFloat,
                                          16) &&
            count == 16) ||
           (_.ContainsSizedIntOrFloatType(elt_type, spv::Op::OpTypeFloat,
                                          32) &&
            count == 8);
  };

  // Is the result a cooperative matrix?
  const auto result_type_inst = _.FindDef(inst->type_id());
  if (!result_type_inst ||
      result_type_inst->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the result type be OpTypeCooperativeMatrixKHR";
  }

  // An unresolved source id is reported like a source of the wrong type
  // instead of being dereferenced.
  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
  const auto source_type_inst = source ? _.FindDef(source->type_id()) : nullptr;

  if (!source_type_inst || source_type_inst->opcode() != spv::Op::OpTypeArray) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the input operand be an OpTypeArray.";
  }

  // Is the scope Subgroup?
  {
    unsigned scope = UINT_MAX;
    unsigned scope_id = result_type_inst->GetOperandAs<unsigned>(2u);
    bool status = _.GetConstantValueAs<unsigned>(scope_id, scope);
    bool is_scope_spec_const = is_spec_constant(scope_id);
    if (!is_scope_spec_const &&
        (!status || scope != static_cast<uint64_t>(spv::Scope::Subgroup))) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "Opcode " << spvOpcodeString(inst->opcode())
             << " requires the result type's scope be Subgroup.";
    }
  }

  // Source array length (operand 2 of OpTypeArray).
  unsigned ar_len = UINT_MAX;
  unsigned src_arr_len_id = source_type_inst->GetOperandAs<unsigned>(2u);
  bool ar_len_status = _.GetConstantValueAs<unsigned>(src_arr_len_id, ar_len);
  bool is_src_arr_len_spec_const = is_spec_constant(src_arr_len_id);

  const auto source_elt_type = _.GetComponentType(source_type_inst->id());
  const auto result_elt_type = result_type_inst->GetOperandAs<uint32_t>(1u);

  if ((source_elt_type != result_elt_type) &&
      !is_uint32_scalar(source_elt_type)) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires either the input element type is equal to the result "
              "element type or it is the unsigned 32-bit integer.";
  }

  // Rows, columns and use operands of the cooperative matrix type.
  unsigned res_row_id = result_type_inst->GetOperandAs<unsigned>(3u);
  unsigned res_col_id = result_type_inst->GetOperandAs<unsigned>(4u);
  unsigned res_use_id = result_type_inst->GetOperandAs<unsigned>(5u);

  unsigned cm_use = UINT_MAX;
  bool cm_use_status = _.GetConstantValueAs<unsigned>(res_use_id, cm_use);

  switch (static_cast<spv::CooperativeMatrixUse>(cm_use)) {
    case spv::CooperativeMatrixUse::MatrixAKHR: {
      // result coopmat component type check
      if (!_.IsIntNOrFP32OrFP16<8>(result_elt_type)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the result element type is one of 8-bit OpTypeInt "
                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
               << " when result coopmat's use is MatrixAKHR";
      }

      // result coopmat column length check
      unsigned n_cols = UINT_MAX;
      bool status = _.GetConstantValueAs<unsigned>(res_col_id, n_cols);
      bool is_res_col_spec_const = is_spec_constant(res_col_id);
      if (!is_res_col_spec_const &&
          (!status || !spans_256_bits(result_elt_type, n_cols))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the columns of the result coopmat have the bit "
                  "length of 256"
               << " when result coopmat's use is MatrixAKHR";
      }
      // source array length check
      if (!is_src_arr_len_spec_const &&
          (!ar_len_status ||
           (!(is_uint32_scalar(source_elt_type) && (ar_len == 8)) &&
            !(n_cols == ar_len)))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source array length be 8 if its elt type is "
                  "32-bit unsigned OpTypeInt and be the result's number of "
                  "columns, otherwise"
               << " when result coopmat's use is MatrixAKHR";
      }
      break;
    }
    case spv::CooperativeMatrixUse::MatrixBKHR: {
      // result coopmat component type check
      if (!_.IsIntNOrFP32OrFP16<8>(result_elt_type)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the result element type is one of 8-bit OpTypeInt "
                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
               << " when result coopmat's use is MatrixBKHR";
      }

      // result coopmat row length check
      unsigned n_rows = UINT_MAX;
      bool status = _.GetConstantValueAs<unsigned>(res_row_id, n_rows);
      bool is_res_row_spec_const = is_spec_constant(res_row_id);
      if (!is_res_row_spec_const &&
          (!status || !spans_256_bits(result_elt_type, n_rows))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the rows of the result operand have the bit "
                  "length of 256"
               << " when result coopmat's use is MatrixBKHR";
      }
      // source array length check
      if (!is_src_arr_len_spec_const &&
          (!ar_len_status ||
           (!(is_uint32_scalar(source_elt_type) && (ar_len == 8)) &&
            !(n_rows == ar_len)))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source array length be 8 if its elt type is "
                  "32-bit unsigned OpTypeInt and be the result's number of "
                  "rows, otherwise"
               << " when result coopmat's use is MatrixBKHR";
      }
      break;
    }
    case spv::CooperativeMatrixUse::MatrixAccumulatorKHR: {
      // result coopmat component type check
      if (!_.IsIntNOrFP32OrFP16<32>(result_elt_type)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the result element type is one of 32-bit "
                  "OpTypeInt signed/unsigned, 16- or 32-bit OpTypeFloat"
               << " when result coopmat's use is MatrixAccumulatorKHR";
      }

      // source array length check
      unsigned n_cols = UINT_MAX;
      bool status = _.GetConstantValueAs<unsigned>(res_col_id, n_cols);
      bool is_res_col_spec_const = is_spec_constant(res_col_id);
      if (!is_res_col_spec_const && !is_src_arr_len_spec_const &&
          (!status || !ar_len_status ||
           (!(is_uint32_scalar(source_elt_type) &&
              (_.ContainsSizedIntOrFloatType(result_elt_type,
                                             spv::Op::OpTypeFloat, 16)
                   ? (n_cols / 2 == ar_len)
                   : n_cols == ar_len)) &&
            (n_cols != ar_len)))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source array length be a half of the number "
                  "of columns of the resulting cooperative matrix if the "
                  "matrix's component type is 16-bit OpTypeFloat and be equal "
                  "to the number of columns, otherwise,"
               << " when result coopmat's use is MatrixAccumulatorKHR";
      }
      break;
    }
    default: {
      // Any other use value is accepted only when the use id is a
      // specialization constant whose value could still be read.
      bool is_cm_use_spec_const = is_spec_constant(res_use_id);
      if (!is_cm_use_spec_const || !cm_use_status) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the resulting cooperative matrix's use be"
               << " one of MatrixAKHR (== 0), MatrixBKHR (== 1), and "
                  "MatrixAccumulatorKHR (== 2)";
      }
      break;
    }
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
// Validates an instruction that extracts an OpTypeArray result from an
// OpTypeCooperativeMatrixKHR source (operand index 2).
//
// Checks performed, in order:
//   * the result type is OpTypeArray and the source type is
//     OpTypeCooperativeMatrixKHR;
//   * the source type's scope is Subgroup (skipped when the scope id is a
//     specialization constant);
//   * per matrix use (MatrixAKHR / MatrixBKHR / MatrixAccumulatorKHR), the
//     source component type, the row/column count, and the result array's
//     element type and length.
//
// Dimension checks are skipped when the relevant id is defined by a
// specialization-constant opcode.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INVALID_DATA with a diagnostic.
spv_result_t ValidateCompositeExtractCoopMatQCOM(ValidationState_t& _,
                                                 const Instruction* inst) {
  // True when |id| is defined by a specialization-constant opcode.
  const auto is_spec_constant = [&_](uint32_t id) -> bool {
    return spvOpcodeIsSpecConstant(_.FindDef(id)->opcode());
  };
  // True when |type_id| is an unsigned 32-bit integer scalar type.
  const auto is_uint32_scalar = [&_](uint32_t type_id) -> bool {
    return _.ContainsSizedIntOrFloatType(type_id, spv::Op::OpTypeInt, 32) &&
           _.IsUnsignedIntScalarType(type_id);
  };
  // True when |count| elements of |elt_type| are one of the accepted
  // combinations: 32 x 8-bit int, 16 x 16-bit float, or 8 x 32-bit float.
  const auto spans_256_bits = [&_](uint32_t elt_type, unsigned count) -> bool {
    return (_.ContainsSizedIntOrFloatType(elt_type, spv::Op::OpTypeInt, 8) &&
            count == 32) ||
           (_.ContainsSizedIntOrFloatType(elt_type, spv::Op::OpTypeFloat,
                                          16) &&
            count == 16) ||
           (_.ContainsSizedIntOrFloatType(elt_type, spv::Op::OpTypeFloat,
                                          32) &&
            count == 8);
  };

  // Is the result an array? (The diagnostic names the result type, which is
  // what this check inspects.)
  const auto result_type_inst = _.FindDef(inst->type_id());
  if (!result_type_inst || result_type_inst->opcode() != spv::Op::OpTypeArray) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the result type be an OpTypeArray.";
  }

  // An unresolved source id is reported like a source of the wrong type
  // instead of being dereferenced.
  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
  const auto source_type_inst = source ? _.FindDef(source->type_id()) : nullptr;

  // Is the source a cooperative matrix?
  if (!source_type_inst ||
      source_type_inst->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the source type be OpTypeCooperativeMatrixKHR";
  }

  // Is the scope Subgroup?
  {
    unsigned scope = UINT_MAX;
    unsigned scope_id = source_type_inst->GetOperandAs<unsigned>(2u);
    bool status = _.GetConstantValueAs<unsigned>(scope_id, scope);
    bool is_scope_spec_const = is_spec_constant(scope_id);
    if (!is_scope_spec_const &&
        (!status || scope != static_cast<uint64_t>(spv::Scope::Subgroup))) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "Opcode " << spvOpcodeString(inst->opcode())
             << " requires the source type's scope be Subgroup.";
    }
  }

  // Result array length (operand 2 of OpTypeArray).
  unsigned ar_len = UINT_MAX;
  unsigned res_arr_len_id = result_type_inst->GetOperandAs<unsigned>(2u);
  bool ar_len_status = _.GetConstantValueAs<unsigned>(res_arr_len_id, ar_len);
  bool is_res_arr_len_spec_const = is_spec_constant(res_arr_len_id);

  const auto source_elt_type = _.GetComponentType(source_type_inst->id());
  const auto result_elt_type = result_type_inst->GetOperandAs<uint32_t>(1u);

  // Rows, columns and use operands of the cooperative matrix type.
  unsigned src_row_id = source_type_inst->GetOperandAs<unsigned>(3u);
  unsigned src_col_id = source_type_inst->GetOperandAs<unsigned>(4u);
  unsigned src_use_id = source_type_inst->GetOperandAs<unsigned>(5u);

  unsigned cm_use = UINT_MAX;
  bool cm_use_status = _.GetConstantValueAs<unsigned>(src_use_id, cm_use);

  switch (static_cast<spv::CooperativeMatrixUse>(cm_use)) {
    case spv::CooperativeMatrixUse::MatrixAKHR: {
      // source coopmat component type check
      if (!_.IsIntNOrFP32OrFP16<8>(source_elt_type)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source element type be one of 8-bit OpTypeInt "
                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
               << " when source coopmat's use is MatrixAKHR";
      }

      // source coopmat column length check
      unsigned n_cols = UINT_MAX;
      bool status = _.GetConstantValueAs<unsigned>(src_col_id, n_cols);
      bool is_src_col_spec_const = is_spec_constant(src_col_id);
      if (!is_src_col_spec_const &&
          (!status || !spans_256_bits(source_elt_type, n_cols))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the columns of the source coopmat have the bit "
                  "length of 256"
               << " when source coopmat's use is MatrixAKHR";
      }
      // result type check
      if (!is_res_arr_len_spec_const &&
          !(source_elt_type == result_elt_type && (n_cols == ar_len)) &&
          !(is_uint32_scalar(result_elt_type) && (ar_len == 8))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires either the result element type be the same as the "
                  "source cooperative matrix's component type"
               << " and its length be the same as the number of columns of the "
                  "matrix or the result element type be"
               << " unsigned 32-bit OpTypeInt and the length be 8"
               << " when source coopmat's use is MatrixAKHR";
      }
      break;
    }
    case spv::CooperativeMatrixUse::MatrixBKHR: {
      // source coopmat component type check
      if (!_.IsIntNOrFP32OrFP16<8>(source_elt_type)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source element type be one of 8-bit OpTypeInt "
                  "signed/unsigned, 16- or 32-bit OpTypeFloat"
               << " when source coopmat's use is MatrixBKHR";
      }

      // source coopmat row length check
      unsigned n_rows = UINT_MAX;
      bool status = _.GetConstantValueAs<unsigned>(src_row_id, n_rows);
      bool is_src_row_spec_const = is_spec_constant(src_row_id);
      if (!is_src_row_spec_const &&
          (!status || !spans_256_bits(source_elt_type, n_rows))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the rows of the source coopmat have the bit "
                  "length of 256"
               << " when source coopmat's use is MatrixBKHR";
      }
      // result type check
      if (!is_res_arr_len_spec_const &&
          !(source_elt_type == result_elt_type && (n_rows == ar_len)) &&
          !(is_uint32_scalar(result_elt_type) && (ar_len == 8))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires either the result element type be the same as the "
                  "source cooperative matrix's component type"
               << " and its length be the same as the number of rows of the "
                  "matrix or the result element type be"
               << " unsigned 32-bit OpTypeInt and the length be 8"
               << " when source coopmat's use is MatrixBKHR";
      }
      break;
    }
    case spv::CooperativeMatrixUse::MatrixAccumulatorKHR: {
      // source coopmat component type check
      if (!_.IsIntNOrFP32OrFP16<32>(source_elt_type)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source element type be one of 32-bit "
                  "OpTypeInt signed/unsigned, 16- or 32-bit OpTypeFloat"
               << " when source coopmat's use is MatrixAccumulatorKHR";
      }

      // result type check
      unsigned n_cols = UINT_MAX;
      bool status = _.GetConstantValueAs<unsigned>(src_col_id, n_cols);
      bool is_src_col_spec_const = is_spec_constant(src_col_id);
      if (!is_src_col_spec_const && !is_res_arr_len_spec_const &&
          (!status || !ar_len_status ||
           (!(source_elt_type == result_elt_type && (n_cols == ar_len)) &&
            !(is_uint32_scalar(result_elt_type) &&
              (_.ContainsSizedIntOrFloatType(source_elt_type,
                                             spv::Op::OpTypeFloat, 16)
                   ? (n_cols / 2 == ar_len)
                   : (n_cols == ar_len)))))) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires either the result element type be the same as the "
                  "source cooperative matrix's component type"
               << " and its length be the same as the number of columns of the "
                  "matrix or the result element type be"
               << " unsigned 32-bit OpTypeInt and the length be the number of "
                  "the columns of the matrix if its component"
               << " type is 32-bit OpTypeFloat and be a half of the number of "
                  "the columns of the matrix if its component"
               << " type is 16-bit OpTypeFloat"
               << " when source coopmat's use is MatrixAccumulatorKHR";
      }
      break;
    }
    default: {
      // Any other use value is accepted only when the use id is a
      // specialization constant whose value could still be read.
      bool is_cm_use_spec_const = is_spec_constant(src_use_id);
      if (!is_cm_use_spec_const || !cm_use_status) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << "Opcode " << spvOpcodeString(inst->opcode())
               << " requires the source cooperative matrix's use be"
               << " one of MatrixAKHR (== 0), MatrixBKHR (== 1), and "
                  "MatrixAccumulatorKHR (== 2)";
      }
      break;
    }
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
// Validates an OpExtractSubArrayQCOM instruction.
//
// Checks performed, in order:
//   1. The result type and the type of the source operand (operand 2) are
//      both OpTypeArray.
//   2. The source and result arrays have the same element type.
//   3. The element type satisfies IsIntNOrFP32OrFP16<32> (reported as 32-bit
//      OpTypeInt, 32-bit OpTypeFloat or 16-bit OpTypeFloat).
//   4. The start index operand (operand 3) is defined and has a 32-bit
//      OpTypeInt type.
//
// Returns SPV_SUCCESS when every check passes; otherwise emits a diagnostic
// on |inst| and returns SPV_ERROR_INVALID_DATA.
spv_result_t ValidateExtractSubArrayQCOM(ValidationState_t& _,
                                         const Instruction* inst) {
  const auto result_type_inst = _.FindDef(inst->type_id());
  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
  // The source operand may be undefined or may not carry a result type; in
  // either case there is no type instruction to inspect, so treat it as
  // "not an array" below instead of dereferencing a null pointer.
  const auto source_type_inst =
      source ? _.FindDef(source->type_id()) : nullptr;

  // Are the input and the result arrays?
  if (!result_type_inst || !source_type_inst ||
      result_type_inst->opcode() != spv::Op::OpTypeArray ||
      source_type_inst->opcode() != spv::Op::OpTypeArray) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires OpTypeArray operands for the input and the result.";
  }

  const auto source_elt_type = _.GetComponentType(source_type_inst->id());
  const auto result_elt_type = _.GetComponentType(result_type_inst->id());

  // Do the input and result element types match?
  if (source_elt_type != result_elt_type) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the input and result element types match.";
  }

  // Elt type must be one of int32_t/uint32_t/float32/float16
  if (!_.IsIntNOrFP32OrFP16<32>(source_elt_type)) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the element type be one of 32-bit OpTypeInt "
              "(signed/unsigned), 32-bit OpTypeFloat and 16-bit OpTypeFloat";
  }

  // The start index must be a defined value whose type is a 32-bit integer.
  const auto start_index = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
  if (!start_index || !_.ContainsSizedIntOrFloatType(start_index->type_id(),
                                                     spv::Op::OpTypeInt, 32)) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Opcode " << spvOpcodeString(inst->opcode())
           << " requires the type of the start index operand be 32-bit "
              "OpTypeInt";
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
} // anonymous namespace
|
||||
// Validates correctness of composite instructions.
|
||||
spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
|
||||
switch (inst->opcode()) {
|
||||
@@ -641,6 +1099,12 @@ spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
|
||||
return ValidateTranspose(_, inst);
|
||||
case spv::Op::OpCopyLogical:
|
||||
return ValidateCopyLogical(_, inst);
|
||||
case spv::Op::OpCompositeConstructCoopMatQCOM:
|
||||
return ValidateCompositeConstructCoopMatQCOM(_, inst);
|
||||
case spv::Op::OpCompositeExtractCoopMatQCOM:
|
||||
return ValidateCompositeExtractCoopMatQCOM(_, inst);
|
||||
case spv::Op::OpExtractSubArrayQCOM:
|
||||
return ValidateExtractSubArrayQCOM(_, inst);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
// Validates correctness of conversion instructions.
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "source/opcode.h"
|
||||
#include "source/spirv_constant.h"
|
||||
#include "source/spirv_target_env.h"
|
||||
@@ -572,26 +574,38 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
if (result_is_pointer && !input_is_pointer && !input_is_int_scalar &&
|
||||
!(input_is_int_vector && input_has_int32))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be a pointer, int scalar or 32-bit int "
|
||||
<< "In SPIR-V 1.5 or later (or with "
|
||||
"SPV_KHR_physical_storage_buffer), expected input to be a "
|
||||
"pointer, "
|
||||
"int scalar or 32-bit int "
|
||||
"vector if Result Type is pointer: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (input_is_pointer && !result_is_pointer && !result_is_int_scalar &&
|
||||
!(result_is_int_vector && result_has_int32))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Pointer can only be converted to another pointer, int "
|
||||
<< "In SPIR-V 1.5 or later (or with "
|
||||
"SPV_KHR_physical_storage_buffer), pointer can only be "
|
||||
"converted to "
|
||||
"another pointer, int "
|
||||
"scalar or 32-bit int vector: "
|
||||
<< spvOpcodeString(opcode);
|
||||
} else {
|
||||
if (result_is_pointer && !input_is_pointer && !input_is_int_scalar)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected input to be a pointer or int scalar if Result "
|
||||
<< "In SPIR-V 1.4 or earlier (and without "
|
||||
"SPV_KHR_physical_storage_buffer), expected input to be a "
|
||||
"pointer "
|
||||
"or int scalar if Result "
|
||||
"Type is pointer: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (input_is_pointer && !result_is_pointer && !result_is_int_scalar)
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Pointer can only be converted to another pointer or int "
|
||||
<< "In SPIR-V 1.4 or earlier (and without "
|
||||
"SPV_KHR_physical_storage_buffer), pointer can only be "
|
||||
"converted "
|
||||
"to another pointer or int "
|
||||
"scalar: "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
@@ -664,6 +678,69 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
break;
|
||||
}
|
||||
|
||||
case spv::Op::OpBitCastArrayQCOM: {
|
||||
const auto result_type_inst = _.FindDef(inst->type_id());
|
||||
const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
|
||||
const auto source_type_inst = _.FindDef(source->type_id());
|
||||
|
||||
// Are the input and the result arrays?
|
||||
if (result_type_inst->opcode() != spv::Op::OpTypeArray ||
|
||||
source_type_inst->opcode() != spv::Op::OpTypeArray) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Opcode " << spvOpcodeString(inst->opcode())
|
||||
<< " requires OpTypeArray operands for the input and the "
|
||||
"result.";
|
||||
}
|
||||
|
||||
const auto source_elt_type = _.GetComponentType(source_type_inst->id());
|
||||
const auto result_elt_type = _.GetComponentType(result_type_inst->id());
|
||||
|
||||
if (!_.IsIntNOrFP32OrFP16<32>(source_elt_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Opcode " << spvOpcodeString(inst->opcode())
|
||||
<< " requires the source element type be one of 32-bit "
|
||||
"OpTypeInt "
|
||||
"(signed/unsigned), 32-bit OpTypeFloat and 16-bit "
|
||||
"OpTypeFloat";
|
||||
}
|
||||
|
||||
if (!_.IsIntNOrFP32OrFP16<32>(result_elt_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Opcode " << spvOpcodeString(inst->opcode())
|
||||
<< " requires the result element type be one of 32-bit "
|
||||
"OpTypeInt "
|
||||
"(signed/unsigned), 32-bit OpTypeFloat and 16-bit "
|
||||
"OpTypeFloat";
|
||||
}
|
||||
|
||||
unsigned src_arr_len_id = source_type_inst->GetOperandAs<unsigned>(2u);
|
||||
unsigned res_arr_len_id = result_type_inst->GetOperandAs<unsigned>(2u);
|
||||
|
||||
// Are the input and result element types compatible?
|
||||
unsigned src_arr_len = UINT_MAX, res_arr_len = UINT_MAX;
|
||||
bool src_arr_len_status =
|
||||
_.GetConstantValueAs<unsigned>(src_arr_len_id, src_arr_len);
|
||||
bool res_arr_len_status =
|
||||
_.GetConstantValueAs<unsigned>(res_arr_len_id, res_arr_len);
|
||||
|
||||
bool is_src_arr_len_spec_const =
|
||||
spvOpcodeIsSpecConstant(_.FindDef(src_arr_len_id)->opcode());
|
||||
bool is_res_arr_len_spec_const =
|
||||
spvOpcodeIsSpecConstant(_.FindDef(res_arr_len_id)->opcode());
|
||||
|
||||
unsigned source_bitlen = _.GetBitWidth(source_elt_type) * src_arr_len;
|
||||
unsigned result_bitlen = _.GetBitWidth(result_elt_type) * res_arr_len;
|
||||
if (!is_src_arr_len_spec_const && !is_res_arr_len_spec_const &&
|
||||
(!src_arr_len_status || !res_arr_len_status ||
|
||||
source_bitlen != result_bitlen)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Opcode " << spvOpcodeString(inst->opcode())
|
||||
<< " requires source and result types be compatible for "
|
||||
"conversion.";
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -398,11 +398,29 @@ bool IsAlignedTo(uint32_t offset, uint32_t alignment) {
|
||||
return 0 == (offset % alignment);
|
||||
}
|
||||
|
||||
// Returns the display name of |sc| for use in layout/decoration diagnostics.
// Storage classes without an explicit entry below are reported as
// "StorageBuffer".
std::string getStorageClassString(spv::StorageClass sc) {
  if (sc == spv::StorageClass::Uniform) return "Uniform";
  if (sc == spv::StorageClass::UniformConstant) return "UniformConstant";
  if (sc == spv::StorageClass::PushConstant) return "PushConstant";
  if (sc == spv::StorageClass::Workgroup) return "Workgroup";
  if (sc == spv::StorageClass::PhysicalStorageBuffer) {
    return "PhysicalStorageBuffer";
  }
  // Only other valid storage class in these checks
  return "StorageBuffer";
}
|
||||
|
||||
// Returns SPV_SUCCESS if the given struct satisfies standard layout rules for
|
||||
// Block or BufferBlocks in Vulkan. Otherwise emits a diagnostic and returns
|
||||
// something other than SPV_SUCCESS. Matrices inherit the specified column
|
||||
// or row major-ness.
|
||||
spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
spv_result_t checkLayout(uint32_t struct_id, spv::StorageClass storage_class,
|
||||
const char* decoration_str, bool blockRules,
|
||||
bool scalar_block_layout, uint32_t incoming_offset,
|
||||
MemberConstraints& constraints,
|
||||
@@ -418,22 +436,48 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
// is more permissive than relaxed layout.
|
||||
const bool relaxed_block_layout = vstate.IsRelaxedBlockLayout();
|
||||
|
||||
auto fail = [&vstate, struct_id, storage_class_str, decoration_str,
|
||||
blockRules, relaxed_block_layout,
|
||||
auto fail = [&vstate, struct_id, storage_class, decoration_str, blockRules,
|
||||
relaxed_block_layout,
|
||||
scalar_block_layout](uint32_t member_idx) -> DiagnosticStream {
|
||||
DiagnosticStream ds =
|
||||
std::move(vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id))
|
||||
<< "Structure id " << struct_id << " decorated as "
|
||||
<< decoration_str << " for variable in " << storage_class_str
|
||||
<< " storage class must follow "
|
||||
<< (scalar_block_layout
|
||||
? "scalar "
|
||||
: (relaxed_block_layout ? "relaxed " : "standard "))
|
||||
<< (blockRules ? "uniform buffer" : "storage buffer")
|
||||
<< " layout rules: member " << member_idx << " ");
|
||||
DiagnosticStream ds = std::move(
|
||||
vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id))
|
||||
<< "Structure id " << struct_id << " decorated as " << decoration_str
|
||||
<< " for variable in " << getStorageClassString(storage_class)
|
||||
<< " storage class must follow "
|
||||
<< (scalar_block_layout
|
||||
? "scalar "
|
||||
: (relaxed_block_layout ? "relaxed " : "standard "))
|
||||
<< (blockRules ? "uniform buffer" : "storage buffer")
|
||||
<< " layout rules: member " << member_idx << " ");
|
||||
return ds;
|
||||
};
|
||||
|
||||
// People often use spirv-val from Vulkan Validation Layers, it ends up
|
||||
// mapping the various block layout rules from the enabled feature. This
|
||||
// offers a hint to help the user understand possibly why things are not
|
||||
// working when the shader itself "seems" valid, but just was a lack of adding
|
||||
// a supported feature
|
||||
auto extra = [&vstate, scalar_block_layout, storage_class,
|
||||
relaxed_block_layout, blockRules]() {
|
||||
if (!scalar_block_layout) {
|
||||
if (storage_class == spv::StorageClass::Workgroup) {
|
||||
return vstate.MissingFeature(
|
||||
"workgroupMemoryExplicitLayoutScalarBlockLayout feature",
|
||||
"--workgroup-scalar-block-layout", true);
|
||||
} else if (!relaxed_block_layout) {
|
||||
return vstate.MissingFeature("VK_KHR_relaxed_block_layout extension",
|
||||
"--relax-block-layout", true);
|
||||
} else if (blockRules) {
|
||||
return vstate.MissingFeature("uniformBufferStandardLayout feature",
|
||||
"--uniform-buffer-standard-layout", true);
|
||||
} else {
|
||||
return vstate.MissingFeature("scalarBlockLayout feature",
|
||||
"--scalar-block-layout", true);
|
||||
}
|
||||
}
|
||||
return std::string("");
|
||||
};
|
||||
|
||||
// If we are checking the layout of untyped pointers or physical storage
|
||||
// buffer pointers, we may not actually have a struct here. Instead, pretend
|
||||
// we have a struct with a single member at offset 0.
|
||||
@@ -507,7 +551,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
const auto size = getSize(id, constraint, constraints, vstate);
|
||||
// Check offset.
|
||||
if (offset == 0xffffffff)
|
||||
return fail(memberIdx) << "is missing an Offset decoration";
|
||||
return fail(memberIdx) << "is missing an Offset decoration" << extra();
|
||||
|
||||
if (opcode == spv::Op::OpTypeRuntimeArray &&
|
||||
ordered_member_idx != member_offsets.size() - 1) {
|
||||
@@ -524,42 +568,44 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
const auto componentId = inst->words()[2];
|
||||
const auto scalar_alignment = getScalarAlignment(componentId, vstate);
|
||||
if (!IsAlignedTo(offset, scalar_alignment)) {
|
||||
return fail(memberIdx)
|
||||
<< "at offset " << offset
|
||||
<< " is not aligned to scalar element size " << scalar_alignment;
|
||||
return fail(memberIdx) << "at offset " << offset
|
||||
<< " is not aligned to scalar element size "
|
||||
<< scalar_alignment << extra();
|
||||
}
|
||||
} else {
|
||||
// Without relaxed block layout, the offset must be divisible by the
|
||||
// alignment requirement.
|
||||
if (!IsAlignedTo(offset, alignment)) {
|
||||
return fail(memberIdx)
|
||||
<< "at offset " << offset << " is not aligned to " << alignment;
|
||||
return fail(memberIdx) << "at offset " << offset
|
||||
<< " is not aligned to " << alignment << extra();
|
||||
}
|
||||
}
|
||||
if (offset < nextValidOffset)
|
||||
return fail(memberIdx) << "at offset " << offset
|
||||
<< " overlaps previous member ending at offset "
|
||||
<< nextValidOffset - 1;
|
||||
<< nextValidOffset - 1 << extra();
|
||||
if (!scalar_block_layout && relaxed_block_layout) {
|
||||
// Check improper straddle of vectors.
|
||||
if (spv::Op::OpTypeVector == opcode &&
|
||||
hasImproperStraddle(id, offset, constraint, constraints, vstate))
|
||||
return fail(memberIdx)
|
||||
<< "is an improperly straddling vector at offset " << offset;
|
||||
<< "is an improperly straddling vector at offset " << offset
|
||||
<< extra();
|
||||
}
|
||||
// Check struct members recursively.
|
||||
spv_result_t recursive_status = SPV_SUCCESS;
|
||||
if (spv::Op::OpTypeStruct == opcode &&
|
||||
SPV_SUCCESS != (recursive_status = checkLayout(
|
||||
id, storage_class_str, decoration_str, blockRules,
|
||||
id, storage_class, decoration_str, blockRules,
|
||||
scalar_block_layout, offset, constraints, vstate)))
|
||||
return recursive_status;
|
||||
// Check matrix stride.
|
||||
if (spv::Op::OpTypeMatrix == opcode) {
|
||||
const auto stride = constraint.matrix_stride;
|
||||
if (!IsAlignedTo(stride, alignment)) {
|
||||
return fail(memberIdx) << "is a matrix with stride " << stride
|
||||
<< " not satisfying alignment to " << alignment;
|
||||
return fail(memberIdx)
|
||||
<< "is a matrix with stride " << stride
|
||||
<< " not satisfying alignment to " << alignment << extra();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -576,12 +622,13 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
if (spv::Decoration::ArrayStride == decoration.dec_type()) {
|
||||
array_stride = decoration.params()[0];
|
||||
if (array_stride == 0) {
|
||||
return fail(memberIdx) << "contains an array with stride 0";
|
||||
return fail(memberIdx)
|
||||
<< "contains an array with stride 0" << extra();
|
||||
}
|
||||
if (!IsAlignedTo(array_stride, array_alignment))
|
||||
return fail(memberIdx)
|
||||
<< "contains an array with stride " << decoration.params()[0]
|
||||
<< " not satisfying alignment to " << alignment;
|
||||
<< " not satisfying alignment to " << alignment << extra();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -608,7 +655,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
|
||||
if (SPV_SUCCESS !=
|
||||
(recursive_status = checkLayout(
|
||||
typeId, storage_class_str, decoration_str, blockRules,
|
||||
typeId, storage_class, decoration_str, blockRules,
|
||||
scalar_block_layout, next_offset, constraints, vstate)))
|
||||
return recursive_status;
|
||||
|
||||
@@ -620,7 +667,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
if (!IsAlignedTo(stride, alignment)) {
|
||||
return fail(memberIdx)
|
||||
<< "is a matrix with stride " << stride
|
||||
<< " not satisfying alignment to " << alignment;
|
||||
<< " not satisfying alignment to " << alignment << extra();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,7 +683,7 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
|
||||
if (element_size > array_stride) {
|
||||
return fail(memberIdx)
|
||||
<< "contains an array with stride " << array_stride
|
||||
<< ", but with an element size of " << element_size;
|
||||
<< ", but with an element size of " << element_size << extra();
|
||||
}
|
||||
}
|
||||
nextValidOffset = offset + size;
|
||||
@@ -801,32 +848,35 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
|
||||
if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
|
||||
if (has_task_payload) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
|
||||
<< "There can be at most one OpVariable with storage "
|
||||
<< "There can be at most one "
|
||||
"OpVariable with storage "
|
||||
"class TaskPayloadWorkgroupEXT associated with "
|
||||
"an OpEntryPoint";
|
||||
}
|
||||
has_task_payload = true;
|
||||
}
|
||||
}
|
||||
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
|
||||
|
||||
// Starting in 1.4, OpEntryPoint must list all global variables
|
||||
// it statically uses and those interfaces must be unique.
|
||||
if (storage_class == spv::StorageClass::Function) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
|
||||
<< "OpEntryPoint interfaces should only list global "
|
||||
<< "In SPIR-V 1.4 or later, OpEntryPoint interfaces should "
|
||||
"only list global "
|
||||
"variables";
|
||||
}
|
||||
|
||||
if (!seen_vars.insert(var_instr).second) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
|
||||
<< "Non-unique OpEntryPoint interface "
|
||||
<< "In SPIR-V 1.4 or later, non-unique OpEntryPoint "
|
||||
"interface "
|
||||
<< vstate.getIdName(interface) << " is disallowed";
|
||||
}
|
||||
} else {
|
||||
if (storage_class != spv::StorageClass::Input &&
|
||||
storage_class != spv::StorageClass::Output) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
|
||||
<< "OpEntryPoint interfaces must be OpVariables with "
|
||||
<< "In SPIR-V 1.3 or earlier, OpEntryPoint interfaces must "
|
||||
"be OpVariables with "
|
||||
"Storage Class of Input(1) or Output(3). Found Storage "
|
||||
"Class "
|
||||
<< uint32_t(storage_class) << " for Entry Point id "
|
||||
@@ -1129,6 +1179,56 @@ void ComputeMemberConstraintsForArray(MemberConstraints* constraints,
|
||||
}
|
||||
}
|
||||
|
||||
// For Vulkan target environments only: walks every OpVariable and
// OpUntypedVariableKHR in the module and requires those in the
// UniformConstant, StorageBuffer, and Uniform storage classes to carry both
// a DescriptorSet and a Binding decoration (VUID-06677).
//
// Returns SPV_SUCCESS if no violation is found (or the target environment is
// not Vulkan); otherwise emits a diagnostic on the offending variable and
// returns SPV_ERROR_INVALID_ID.
spv_result_t CheckDecorationsOfVariables(ValidationState_t& vstate) {
  if (!spvIsVulkanEnv(vstate.context()->target_env)) return SPV_SUCCESS;

  for (const auto& inst : vstate.ordered_instructions()) {
    const spv::Op op = inst.opcode();
    if (op != spv::Op::OpVariable && op != spv::Op::OpUntypedVariableKHR) {
      continue;
    }

    const auto var_id = inst.id();

    // Pick the name used in the diagnostics; variables in any other storage
    // class are not subject to this check.
    const char* sc_name = nullptr;
    switch (inst.GetOperandAs<spv::StorageClass>(2)) {
      case spv::StorageClass::Uniform:
        sc_name = "Uniform";
        break;
      case spv::StorageClass::UniformConstant:
        sc_name = "UniformConstant";
        break;
      case spv::StorageClass::StorageBuffer:
        sc_name = "StorageBuffer";
        break;
      default:
        continue;
    }

    // Skip validation if the variable is not used and we're looking
    // at a module coming from HLSL that has not been legalized yet.
    if (vstate.options()->before_hlsl_legalization &&
        vstate.EntryPointReferences(var_id).empty()) {
      continue;
    }

    // Emits the VUID-06677 diagnostic; |missing_part| is the sentence
    // fragment naming the absent decoration.
    const auto report_missing =
        [&vstate, var_id, sc_name](const char* missing_part) -> spv_result_t {
      return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
             << vstate.VkErrorID(6677) << sc_name << " id '" << var_id
             << missing_part << "From Vulkan spec:\n"
             << "These variables must have DescriptorSet and Binding "
                "decorations specified";
    };

    if (!hasDecoration(var_id, spv::Decoration::DescriptorSet, vstate)) {
      return report_missing("' is missing DescriptorSet decoration.\n");
    }
    if (!hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
      return report_missing("' is missing Binding decoration.\n");
    }
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
// Set of entry points that are known to use a push constant.
|
||||
std::unordered_set<uint32_t> uses_push_constant;
|
||||
@@ -1148,8 +1248,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
const auto storageClassVal = words[3];
|
||||
const auto storageClass = spv::StorageClass(storageClassVal);
|
||||
const bool uniform = storageClass == spv::StorageClass::Uniform;
|
||||
const bool uniform_constant =
|
||||
storageClass == spv::StorageClass::UniformConstant;
|
||||
const bool push_constant =
|
||||
storageClass == spv::StorageClass::PushConstant;
|
||||
const bool storage_buffer =
|
||||
@@ -1172,29 +1270,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Vulkan: Check DescriptorSet and Binding decoration for
|
||||
// UniformConstant which cannot be a struct.
|
||||
if (uniform_constant) {
|
||||
auto entry_points = vstate.EntryPointReferences(var_id);
|
||||
if (!entry_points.empty() &&
|
||||
!hasDecoration(var_id, spv::Decoration::DescriptorSet, vstate)) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
|
||||
<< vstate.VkErrorID(6677) << "UniformConstant id '" << var_id
|
||||
<< "' is missing DescriptorSet decoration.\n"
|
||||
<< "From Vulkan spec:\n"
|
||||
<< "These variables must have DescriptorSet and Binding "
|
||||
"decorations specified";
|
||||
}
|
||||
if (!entry_points.empty() &&
|
||||
!hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
|
||||
<< vstate.VkErrorID(6677) << "UniformConstant id '" << var_id
|
||||
<< "' is missing Binding decoration.\n"
|
||||
<< "From Vulkan spec:\n"
|
||||
<< "These variables must have DescriptorSet and Binding "
|
||||
"decorations specified";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (spvIsOpenGLEnv(vstate.context()->target_env)) {
|
||||
@@ -1207,8 +1282,8 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
if (!entry_points.empty() &&
|
||||
!hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
|
||||
<< (uniform ? "Uniform" : "Storage Buffer") << " id '"
|
||||
<< var_id << "' is missing Binding decoration.\n"
|
||||
<< getStorageClassString(storageClass) << " id '" << var_id
|
||||
<< "' is missing Binding decoration.\n"
|
||||
<< "From ARB_gl_spirv extension:\n"
|
||||
<< "Uniform and shader storage block variables must "
|
||||
<< "also be decorated with a *Binding*.";
|
||||
@@ -1243,12 +1318,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
ComputeMemberConstraintsForStruct(&constraints, id,
|
||||
LayoutConstraints(), vstate);
|
||||
}
|
||||
// Prepare for messages
|
||||
const char* sc_str =
|
||||
uniform
|
||||
? "Uniform"
|
||||
: (push_constant ? "PushConstant"
|
||||
: (workgroup ? "Workgroup" : "StorageBuffer"));
|
||||
|
||||
if (spvIsVulkanEnv(vstate.context()->target_env)) {
|
||||
const bool block = hasDecoration(id, spv::Decoration::Block, vstate);
|
||||
@@ -1286,30 +1355,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
<< "Such variables must be identified with a Block or "
|
||||
"BufferBlock decoration";
|
||||
}
|
||||
// Vulkan: Check DescriptorSet and Binding decoration for
|
||||
// Uniform and StorageBuffer variables.
|
||||
if (uniform || storage_buffer) {
|
||||
auto entry_points = vstate.EntryPointReferences(var_id);
|
||||
if (!entry_points.empty() &&
|
||||
!hasDecoration(var_id, spv::Decoration::DescriptorSet,
|
||||
vstate)) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
|
||||
<< vstate.VkErrorID(6677) << sc_str << " id '" << var_id
|
||||
<< "' is missing DescriptorSet decoration.\n"
|
||||
<< "From Vulkan spec:\n"
|
||||
<< "These variables must have DescriptorSet and Binding "
|
||||
"decorations specified";
|
||||
}
|
||||
if (!entry_points.empty() &&
|
||||
!hasDecoration(var_id, spv::Decoration::Binding, vstate)) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
|
||||
<< vstate.VkErrorID(6677) << sc_str << " id '" << var_id
|
||||
<< "' is missing Binding decoration.\n"
|
||||
<< "From Vulkan spec:\n"
|
||||
<< "These variables must have DescriptorSet and Binding "
|
||||
"decorations specified";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (id != 0) {
|
||||
@@ -1386,14 +1431,14 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
if (spvIsVulkanEnv(vstate.context()->target_env)) {
|
||||
if (blockRules &&
|
||||
(SPV_SUCCESS !=
|
||||
(recursive_status = checkLayout(id, sc_str, deco_str, true,
|
||||
scalar_block_layout, 0,
|
||||
constraints, vstate)))) {
|
||||
(recursive_status = checkLayout(
|
||||
id, storageClass, deco_str, true, scalar_block_layout,
|
||||
0, constraints, vstate)))) {
|
||||
return recursive_status;
|
||||
} else if (bufferRules &&
|
||||
(SPV_SUCCESS != (recursive_status = checkLayout(
|
||||
id, sc_str, deco_str, false,
|
||||
scalar_block_layout, 0,
|
||||
id, storageClass, deco_str,
|
||||
false, scalar_block_layout, 0,
|
||||
constraints, vstate)))) {
|
||||
return recursive_status;
|
||||
}
|
||||
@@ -1413,9 +1458,9 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
ComputeMemberConstraintsForStruct(&constraints, pointee_type_id,
|
||||
LayoutConstraints(), vstate);
|
||||
}
|
||||
if (auto res = checkLayout(pointee_type_id, "PhysicalStorageBuffer",
|
||||
"Block", !buffer, scalar_block_layout, 0,
|
||||
constraints, vstate)) {
|
||||
if (auto res = checkLayout(
|
||||
pointee_type_id, spv::StorageClass::PhysicalStorageBuffer,
|
||||
"Block", !buffer, scalar_block_layout, 0, constraints, vstate)) {
|
||||
return res;
|
||||
}
|
||||
} else if (vstate.HasCapability(spv::Capability::UntypedPointersKHR) &&
|
||||
@@ -1464,14 +1509,6 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
const auto sc =
|
||||
vstate.FindDef(ptr_ty_id)->GetOperandAs<spv::StorageClass>(1);
|
||||
|
||||
const char* sc_str =
|
||||
sc == spv::StorageClass::Uniform
|
||||
? "Uniform"
|
||||
: (sc == spv::StorageClass::PushConstant
|
||||
? "PushConstant"
|
||||
: (sc == spv::StorageClass::Workgroup ? "Workgroup"
|
||||
: "StorageBuffer"));
|
||||
|
||||
auto data_type = vstate.FindDef(data_type_id);
|
||||
scalar_block_layout =
|
||||
sc == spv::StorageClass::Workgroup
|
||||
@@ -1511,7 +1548,7 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
|
||||
? (sc == spv::StorageClass::Uniform ? "BufferBlock" : "Block")
|
||||
: "Block";
|
||||
if (auto result =
|
||||
checkLayout(data_type_id, sc_str, deco_str, !bufferRules,
|
||||
checkLayout(data_type_id, sc, deco_str, !bufferRules,
|
||||
scalar_block_layout, 0, constraints, vstate)) {
|
||||
return result;
|
||||
}
|
||||
@@ -1732,14 +1769,19 @@ spv_result_t CheckFPRoundingModeForShaders(ValidationState_t& vstate,
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Returns SPV_SUCCESS if validation rules are satisfied for the NonWritable
|
||||
// Returns SPV_SUCCESS if validation rules are satisfied for the NonReadable or
|
||||
// NonWritable
|
||||
// decoration. Otherwise emits a diagnostic and returns something other than
|
||||
// SPV_SUCCESS. The |inst| parameter is the object being decorated. This must
|
||||
// be called after TypePass and AnnotateCheckDecorationsOfBuffers are called.
|
||||
spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
|
||||
const Instruction& inst,
|
||||
const Decoration& decoration) {
|
||||
spv_result_t CheckNonReadableWritableDecorations(ValidationState_t& vstate,
|
||||
const Instruction& inst,
|
||||
const Decoration& decoration) {
|
||||
assert(inst.id() && "Parser ensures the target of the decoration has an ID");
|
||||
const bool is_non_writable =
|
||||
decoration.dec_type() == spv::Decoration::NonWritable;
|
||||
assert(is_non_writable ||
|
||||
decoration.dec_type() == spv::Decoration::NonReadable);
|
||||
|
||||
if (decoration.struct_member_index() == Decoration::kInvalidMember) {
|
||||
// The target must be a memory object declaration.
|
||||
@@ -1751,7 +1793,10 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
|
||||
opcode != spv::Op::OpFunctionParameter &&
|
||||
opcode != spv::Op::OpRawAccessChainNV) {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
|
||||
<< "Target of NonWritable decoration must be a memory object "
|
||||
<< "Target of "
|
||||
<< (is_non_writable ? "NonWritable" : "NonReadable")
|
||||
<< " decoration must be a "
|
||||
"memory object "
|
||||
"declaration (a variable or a function parameter)";
|
||||
}
|
||||
const auto var_storage_class =
|
||||
@@ -1762,7 +1807,8 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
|
||||
: spv::StorageClass::Max;
|
||||
if ((var_storage_class == spv::StorageClass::Function ||
|
||||
var_storage_class == spv::StorageClass::Private) &&
|
||||
vstate.features().nonwritable_var_in_function_or_private) {
|
||||
vstate.features().nonwritable_var_in_function_or_private &&
|
||||
is_non_writable) {
|
||||
// New permitted feature in SPIR-V 1.4.
|
||||
} else if (var_storage_class == spv::StorageClass::TileAttachmentQCOM) {
|
||||
} else if (
|
||||
@@ -1770,12 +1816,18 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
|
||||
vstate.IsPointerToUniformBlock(type_id) ||
|
||||
vstate.IsPointerToStorageBuffer(type_id) ||
|
||||
vstate.IsPointerToStorageImage(type_id) ||
|
||||
vstate.IsPointerToTensor(type_id) ||
|
||||
opcode == spv::Op::OpRawAccessChainNV) {
|
||||
} else {
|
||||
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
|
||||
<< "Target of NonWritable decoration is invalid: must point to a "
|
||||
"storage image, uniform block, "
|
||||
<< (vstate.features().nonwritable_var_in_function_or_private
|
||||
<< "Target of "
|
||||
<< (is_non_writable ? "NonWritable" : "NonReadable")
|
||||
<< " decoration is invalid: "
|
||||
"must point to a "
|
||||
"storage image, tensor variable in UniformConstant storage "
|
||||
"class, uniform block, "
|
||||
<< (vstate.features().nonwritable_var_in_function_or_private &&
|
||||
is_non_writable
|
||||
? "storage buffer, or variable in Private or Function "
|
||||
"storage class"
|
||||
: "or storage buffer");
|
||||
@@ -2063,8 +2115,10 @@ spv_result_t CheckDecorationsFromDecoration(ValidationState_t& vstate) {
|
||||
PASS_OR_BAIL(
|
||||
CheckFPRoundingModeForShaders(vstate, *inst, decoration));
|
||||
break;
|
||||
case spv::Decoration::NonReadable:
|
||||
case spv::Decoration::NonWritable:
|
||||
PASS_OR_BAIL(CheckNonWritableDecoration(vstate, *inst, decoration));
|
||||
PASS_OR_BAIL(
|
||||
CheckNonReadableWritableDecorations(vstate, *inst, decoration));
|
||||
break;
|
||||
case spv::Decoration::Uniform:
|
||||
case spv::Decoration::UniformId:
|
||||
@@ -2298,6 +2352,7 @@ spv_result_t ValidateDecorations(ValidationState_t& vstate) {
|
||||
if (auto error = CheckImportedVariableInitialization(vstate)) return error;
|
||||
if (auto error = CheckDecorationsOfEntryPoints(vstate)) return error;
|
||||
if (auto error = CheckDecorationsOfBuffers(vstate)) return error;
|
||||
if (auto error = CheckDecorationsOfVariables(vstate)) return error;
|
||||
if (auto error = CheckDecorationsCompatibility(vstate)) return error;
|
||||
if (auto error = CheckLinkageAttrOfFunctions(vstate)) return error;
|
||||
if (auto error = CheckVulkanMemoryModelDeprecatedDecorations(vstate))
|
||||
|
||||
@@ -1052,7 +1052,9 @@ bool IsDebugVariableWithIntScalarType(ValidationState_t& _,
|
||||
spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
|
||||
std::string extension = GetExtensionString(&(inst->c_inst()));
|
||||
if (_.version() < SPV_SPIRV_VERSION_WORD(1, 3)) {
|
||||
if (extension == ExtensionToString(kSPV_KHR_vulkan_memory_model)) {
|
||||
if (extension == ExtensionToString(kSPV_KHR_vulkan_memory_model) ||
|
||||
extension ==
|
||||
ExtensionToString(kSPV_QCOM_cooperative_matrix_conversion)) {
|
||||
return _.diag(SPV_ERROR_WRONG_VERSION, inst)
|
||||
<< extension << " extension requires SPIR-V version 1.3 or later.";
|
||||
}
|
||||
@@ -1064,7 +1066,9 @@ spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
|
||||
extension == ExtensionToString(kSPV_NV_shader_invocation_reorder) ||
|
||||
extension ==
|
||||
ExtensionToString(kSPV_NV_cluster_acceleration_structure) ||
|
||||
extension == ExtensionToString(kSPV_NV_linear_swept_spheres)) {
|
||||
extension == ExtensionToString(kSPV_NV_linear_swept_spheres) ||
|
||||
extension == ExtensionToString(kSPV_QCOM_image_processing) ||
|
||||
extension == ExtensionToString(kSPV_QCOM_image_processing2)) {
|
||||
return _.diag(SPV_ERROR_WRONG_VERSION, inst)
|
||||
<< extension << " extension requires SPIR-V version 1.4 or later.";
|
||||
}
|
||||
@@ -1081,8 +1085,10 @@ spv_result_t ValidateExtInstImport(ValidationState_t& _,
|
||||
const std::string name = inst->GetOperandAs<std::string>(name_id);
|
||||
if (name.find("NonSemantic.") == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "NonSemantic extended instruction sets cannot be declared "
|
||||
"without SPV_KHR_non_semantic_info.";
|
||||
<< "NonSemantic extended instruction "
|
||||
"sets cannot be declared "
|
||||
"without SPV_KHR_non_semantic_info. (This can also be fixed "
|
||||
"having SPIR-V 1.6 or later)";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -89,7 +89,10 @@ spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
|
||||
spv::Op::OpName,
|
||||
spv::Op::OpCooperativeMatrixPerElementOpNV,
|
||||
spv::Op::OpCooperativeMatrixReduceNV,
|
||||
spv::Op::OpCooperativeMatrixLoadTensorNV};
|
||||
spv::Op::OpCooperativeMatrixLoadTensorNV,
|
||||
spv::Op::OpConditionalEntryPointINTEL,
|
||||
};
|
||||
|
||||
for (auto& pair : inst->uses()) {
|
||||
const auto* use = pair.first;
|
||||
if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
|
||||
@@ -109,11 +112,6 @@ spv_result_t ValidateFunctionParameter(ValidationState_t& _,
|
||||
// NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
|
||||
size_t param_index = 0;
|
||||
size_t inst_num = inst->LineNum() - 1;
|
||||
if (inst_num == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< "Function parameter cannot be the first instruction.";
|
||||
}
|
||||
|
||||
auto func_inst = &_.ordered_instructions()[inst_num];
|
||||
while (--inst_num) {
|
||||
func_inst = &_.ordered_instructions()[inst_num];
|
||||
|
||||
547
3rdparty/spirv-tools/source/val/validate_graph.cpp
vendored
Normal file
547
3rdparty/spirv-tools/source/val/validate_graph.cpp
vendored
Normal file
@@ -0,0 +1,547 @@
|
||||
// Copyright (c) 2023-2025 Arm Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Validates correctness of graph instructions.
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "source/opcode.h"
|
||||
#include "source/val/validate.h"
|
||||
#include "source/val/validation_state.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace val {
|
||||
namespace {
|
||||
|
||||
// Returns true when |id| is defined by an OpTypeArray or OpTypeRuntimeArray
// whose word 2 refers to an OpTypeTensorARM definition. Returns false when
// either definition cannot be found.
bool IsTensorArray(ValidationState_t& _, uint32_t id) {
  const auto array_type = _.FindDef(id);
  if (array_type == nullptr) return false;

  const spv::Op array_op = array_type->opcode();
  const bool is_array = array_op == spv::Op::OpTypeArray ||
                        array_op == spv::Op::OpTypeRuntimeArray;
  if (!is_array) return false;

  const auto element_type = _.FindDef(array_type->word(2));
  return element_type != nullptr &&
         element_type->opcode() == spv::Op::OpTypeTensorARM;
}
|
||||
|
||||
// Returns true when |id| is a tensor type or an array of tensors; these are
// the types this file accepts as graph inputs and outputs.
bool IsGraphInterfaceType(ValidationState_t& _, uint32_t id) {
  if (_.IsTensorType(id)) return true;
  return IsTensorArray(_, id);
}
|
||||
|
||||
// Returns true when |id| is defined by an OpGraphARM instruction.
bool IsGraph(ValidationState_t& _, uint32_t id) {
  const auto definition = _.FindDef(id);
  return definition != nullptr &&
         definition->opcode() == spv::Op::OpGraphARM;
}
|
||||
|
||||
// Returns true when |id| is defined by an OpTypeGraphARM instruction.
bool IsGraphType(ValidationState_t& _, uint32_t id) {
  const auto definition = _.FindDef(id);
  return definition != nullptr &&
         definition->opcode() == spv::Op::OpTypeGraphARM;
}
|
||||
|
||||
// Index of the first input/output type word in an OpTypeGraphARM instruction
// (the helpers below read word 2 as NumInputs and the I/O types from here on).
constexpr uint32_t kGraphTypeIOStartWord = 3;
|
||||
|
||||
uint32_t GraphTypeInstNumIO(const Instruction* inst) {
|
||||
return static_cast<uint32_t>(inst->words().size()) - kGraphTypeIOStartWord;
|
||||
}
|
||||
|
||||
uint32_t GraphTypeInstNumInputs(const Instruction* inst) {
|
||||
return inst->word(2);
|
||||
}
|
||||
|
||||
uint32_t GraphTypeInstNumOutputs(const Instruction* inst) {
|
||||
return GraphTypeInstNumIO(inst) - GraphTypeInstNumInputs(inst);
|
||||
}
|
||||
|
||||
uint32_t GraphTypeInstGetOutputAtIndex(const Instruction* inst,
|
||||
uint64_t index) {
|
||||
return inst->word(kGraphTypeIOStartWord + GraphTypeInstNumInputs(inst) +
|
||||
static_cast<uint32_t>(index));
|
||||
}
|
||||
|
||||
uint32_t GraphTypeInstGetInputAtIndex(const Instruction* inst, uint64_t index) {
|
||||
return inst->word(kGraphTypeIOStartWord + static_cast<uint32_t>(index));
|
||||
}
|
||||
|
||||
// Validates an OpTypeGraphARM instruction:
//  - it must list at least NumInputs I/O types,
//  - at least one of the listed types must be an output,
//  - every listed type must be a graph interface type.
// Returns SPV_SUCCESS or emits a diagnostic and returns
// SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraphType(ValidationState_t& _, const Instruction* inst) {
  const uint32_t num_inputs = GraphTypeInstNumInputs(inst);
  const size_t num_io_types = GraphTypeInstNumIO(inst);

  // Fewer I/O types than declared inputs.
  if (num_io_types < num_inputs) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << num_io_types << " I/O types were provided but the graph has "
           << num_inputs << " inputs.";
  }

  // Every listed type is an input, so nothing is left for outputs.
  if (num_io_types == num_inputs) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "A graph type must have at least one output.";
  }

  // Each I/O type has to be a tensor or an array of tensors.
  const size_t word_count = inst->words().size();
  for (size_t word_index = kGraphTypeIOStartWord; word_index < word_count;
       ++word_index) {
    const uint32_t io_type_id = inst->word(word_index);
    if (IsGraphInterfaceType(_, io_type_id)) continue;
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "I/O type " << _.getIdName(io_type_id)
           << " is not a Graph Interface Type.";
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
// Validates an OpGraphConstantARM instruction:
//  - its Result Type must be a tensor type,
//  - no earlier OpGraphConstantARM may carry the same GraphConstantID
//    (word 3).
// Returns SPV_SUCCESS or emits a diagnostic and returns
// SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraphConstant(ValidationState_t& _,
                                   const Instruction* inst) {
  // Check Result Type
  if (!_.IsTensorType(inst->type_id())) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << spvOpcodeString(inst->opcode())
           << " must have a Result Type that is a tensor type.";
  }

  // Check the instruction is not preceded by another OpGraphConstantARM with
  // the same ID.
  const uint32_t cst_id = inst->word(3);
  const auto& instructions = _.ordered_instructions();

  // LineNum() - 1 is used throughout this file as the position of |inst| in
  // ordered_instructions() (presumably LineNum() is 1-based). The loop visits
  // every earlier instruction, index 0 included, and cannot wrap around when
  // the starting index is 0.
  const size_t line_num = inst->LineNum();
  for (size_t index = (line_num == 0 ? 0 : line_num - 1); index-- > 0;) {
    const auto& prev_inst = instructions[index];
    if (prev_inst.opcode() != spv::Op::OpGraphConstantARM) continue;
    if (prev_inst.word(3) == cst_id) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "No two OpGraphConstantARM instructions may have the same "
                "GraphConstantID";
    }
  }
  return SPV_SUCCESS;
}
|
||||
|
||||
// Validates an OpGraphEntryPointARM instruction:
//  - operand 0 must refer to an OpGraphARM,
//  - the number of Interface IDs (operands 2 and up) must equal the number of
//    inputs and outputs of the graph's type,
//  - every Interface ID must be an OpVariable in the UniformConstant storage
//    class whose pointee type equals the corresponding graph I/O type.
// Returns SPV_SUCCESS or emits a diagnostic and returns
// SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraphEntryPoint(ValidationState_t& _,
                                     const Instruction* inst) {
  // Graph must be an OpGraphARM
  const uint32_t graph = inst->GetOperandAs<uint32_t>(0);
  const auto graph_inst = _.FindDef(graph);
  if (!IsGraph(_, graph)) {
    // |graph_inst| is null when the ID has no definition; do not dereference
    // it in that case.
    const std::string found =
        graph_inst ? std::string(spvOpcodeString(graph_inst->opcode()))
                   : std::string("an undefined ID");
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << spvOpcodeString(inst->opcode())
           << " Graph must be a OpGraphARM but found " << found << ".";
  }

  // Resolve the graph's type before reading anything from it.
  const auto graph_type_inst = _.FindDef(graph_inst->type_id());
  if (!graph_type_inst ||
      graph_type_inst->opcode() != spv::Op::OpTypeGraphARM) {
    // This is invalid but we want ValidateGraph to report a clear error
    // so stop validating the graph entry point instruction
    return SPV_SUCCESS;
  }

  // Check number of Interface IDs matches number of I/Os of graph
  const size_t graph_type_num_io = GraphTypeInstNumIO(graph_type_inst);
  const size_t graph_entry_point_num_interface_id =
      inst->operands().size() - 2;
  if (graph_type_num_io != graph_entry_point_num_interface_id) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << spvOpcodeString(inst->opcode()) << " Interface list contains "
           << graph_entry_point_num_interface_id << " IDs but Graph's type "
           << _.getIdName(graph_inst->type_id()) << " has " << graph_type_num_io
           << " inputs and outputs.";
  }

  // Check Interface IDs. Operand i of the entry point corresponds to operand i
  // of the graph type: both lists start at operand 2.
  for (uint32_t i = 2; i < inst->operands().size(); i++) {
    const uint32_t interface_id = inst->GetOperandAs<uint32_t>(i);
    const auto interface_inst = _.FindDef(interface_id);

    // Check interface IDs come from OpVariable
    if (!interface_inst ||
        (interface_inst->opcode() != spv::Op::OpVariable) ||
        (interface_inst->GetOperandAs<spv::StorageClass>(2) !=
         spv::StorageClass::UniformConstant)) {
      // Attach the diagnostic to the interface definition when there is one,
      // otherwise to the entry point itself.
      const Instruction* diag_inst = interface_inst ? interface_inst : inst;
      return _.diag(SPV_ERROR_INVALID_DATA, diag_inst)
             << spvOpcodeString(inst->opcode()) << " Interface ID "
             << _.getIdName(interface_id)
             << " must come from OpVariable with UniformConstant Storage "
                "Class.";
    }

    // Check type of interface variable matches type of the corresponding graph
    // I/O
    const uint32_t corresponding_graph_io_type =
        graph_type_inst->GetOperandAs<uint32_t>(i);

    // Operand 2 is only read from an OpTypePointer. Any other (or missing)
    // result type yields pointee ID 0, which falls through to the mismatch
    // diagnostic below.
    const uint32_t interface_ptr_type = interface_inst->type_id();
    const auto interface_ptr_inst = _.FindDef(interface_ptr_type);
    const uint32_t interface_pointee_type =
        (interface_ptr_inst &&
         interface_ptr_inst->opcode() == spv::Op::OpTypePointer)
            ? interface_ptr_inst->GetOperandAs<uint32_t>(2)
            : 0u;
    if (interface_pointee_type != corresponding_graph_io_type) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << spvOpcodeString(inst->opcode()) << " Interface ID type "
             << _.getIdName(interface_pointee_type)
             << " must match the type of the corresponding graph I/O "
             << _.getIdName(corresponding_graph_io_type);
    }
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
// Validates an OpGraphARM instruction: its Result Type must be defined by an
// OpTypeGraphARM. Returns SPV_SUCCESS or emits a diagnostic and returns
// SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraph(ValidationState_t& _, const Instruction* inst) {
  const uint32_t result_type_id = inst->type_id();
  if (IsGraphType(_, result_type_id)) {
    return SPV_SUCCESS;
  }

  return _.diag(SPV_ERROR_INVALID_DATA, inst)
         << spvOpcodeString(inst->opcode())
         << " Result Type must be an OpTypeGraphARM.";
}
|
||||
|
||||
// Validates an OpGraphInputARM instruction:
//  - InputIndex (operand 2) and the optional ElementIndex (operand 3) must be
//    32-bit integers,
//  - when InputIndex can be evaluated it must be smaller than the number of
//    inputs of the enclosing graph's type,
//  - ElementIndex is only allowed when the selected input is an array of
//    tensors, and must be in range for fixed-size arrays whose length can be
//    evaluated,
//  - the Result Type must match the selected input type (or its component
//    type when ElementIndex is present).
// Returns SPV_SUCCESS or emits a diagnostic and returns
// SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraphInput(ValidationState_t& _, const Instruction* inst) {
  // Check type of InputIndex
  const uint32_t input_index_id = inst->GetOperandAs<uint32_t>(2);
  const auto input_index_inst = _.FindDef(input_index_id);
  if (!input_index_inst ||
      !_.IsIntScalarType(input_index_inst->type_id(), 32)) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << spvOpcodeString(inst->opcode())
           << " InputIndex must be a 32-bit integer.";
  }

  const bool has_element_index = inst->operands().size() > 3;

  // Check type of ElementIndex
  if (has_element_index) {
    const auto element_index_inst =
        _.FindDef(inst->GetOperandAs<uint32_t>(3));
    if (!element_index_inst ||
        !_.IsIntScalarType(element_index_inst->type_id(), 32)) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << spvOpcodeString(inst->opcode())
             << " ElementIndex must be a 32-bit integer.";
    }
  }

  // Find graph definition: the closest preceding OpGraphARM. LineNum() - 1 is
  // used throughout this file as the position of |inst| in
  // ordered_instructions(). The loop cannot wrap around when that position is
  // 0, and |graph_inst| stays null when no graph is found.
  const auto& instructions = _.ordered_instructions();
  const Instruction* graph_inst = nullptr;
  const size_t line_num = inst->LineNum();
  for (size_t index = (line_num == 0 ? 0 : line_num - 1); index-- > 0;) {
    if (instructions[index].opcode() == spv::Op::OpGraphARM) {
      graph_inst = &instructions[index];
      break;
    }
  }

  // Can the InputIndex be evaluated?
  // If not, there's nothing more we can validate here.
  uint64_t input_index = 0;
  if (!_.EvalConstantValUint64(input_index_id, &input_index)) {
    return SPV_SUCCESS;
  }

  // Without an enclosing graph whose type is an OpTypeGraphARM there is
  // nothing to compare against. Stop here rather than read from an unrelated
  // instruction. NOTE(review): such modules are presumably rejected by
  // ValidateGraph or by layout validation - confirm.
  if (graph_inst == nullptr) {
    return SPV_SUCCESS;
  }
  const auto graph_type_inst = _.FindDef(graph_inst->type_id());
  if (!graph_type_inst ||
      graph_type_inst->opcode() != spv::Op::OpTypeGraphARM) {
    return SPV_SUCCESS;
  }
  const size_t graph_type_num_inputs =
      graph_type_inst->GetOperandAs<uint32_t>(1);

  // Check InputIndex is in range
  if (input_index >= graph_type_num_inputs) {
    std::string disassembly = _.Disassemble(*inst);
    return _.diag(SPV_ERROR_INVALID_DATA, nullptr)
           << "Type " << _.getIdName(graph_type_inst->id()) << " for graph "
           << _.getIdName(graph_inst->id()) << " has " << graph_type_num_inputs
           << " inputs but found an OpGraphInputARM instruction with an "
              "InputIndex that is "
           << input_index << ": " << disassembly;
  }

  const uint32_t graph_type_input_type =
      GraphTypeInstGetInputAtIndex(graph_type_inst, input_index);

  if (has_element_index) {
    // Check ElementIndex is allowed
    if (!IsTensorArray(_, graph_type_input_type)) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "OpGraphInputARM ElementIndex not allowed when the graph input "
                "selected by "
             << "InputIndex is not an OpTypeArray or OpTypeRuntimeArray";
    }

    // Check ElementIndex is in range if it can be evaluated and the input is a
    // fixed-sized array whose Length can be evaluated
    uint64_t element_index = 0;
    if (_.IsArrayType(graph_type_input_type) &&
        _.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(3),
                                &element_index)) {
      uint64_t array_length = 0;
      const auto graph_type_input_type_inst = _.FindDef(graph_type_input_type);
      if (_.EvalConstantValUint64(
              graph_type_input_type_inst->GetOperandAs<uint32_t>(2),
              &array_length)) {
        if (element_index >= array_length) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << "OpGraphInputARM ElementIndex out of range. The type of "
                    "the graph input being accessed "
                 << _.getIdName(graph_type_input_type) << " is an array of "
                 << array_length << " elements but " << "ElementIndex is "
                 << element_index;
        }
      }
    }
  }

  // Check result type matches with graph type
  if (has_element_index) {
    const uint32_t expected_type = _.GetComponentType(graph_type_input_type);
    if (inst->type_id() != expected_type) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "Result Type " << _.getIdName(inst->type_id())
             << " of graph input instruction " << _.getIdName(inst->id())
             << " does not match the component type "
             << _.getIdName(expected_type) << " of input " << input_index
             << " in the graph type.";
    }
  } else {
    if (inst->type_id() != graph_type_input_type) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "Result Type " << _.getIdName(inst->type_id())
             << " of graph input instruction " << _.getIdName(inst->id())
             << " does not match the type "
             << _.getIdName(graph_type_input_type) << " of input "
             << input_index << " in the graph type.";
    }
  }
  return SPV_SUCCESS;
}
|
||||
|
||||
// Validates an OpGraphSetOutputARM instruction:
//  - OutputIndex (operand 1) and the optional ElementIndex (operand 2) must be
//    32-bit integers,
//  - when OutputIndex can be evaluated it must be smaller than the number of
//    outputs of the enclosing graph's type,
//  - ElementIndex is only allowed when the selected output is an array of
//    tensors, and must be in range for fixed-size arrays whose length can be
//    evaluated,
//  - the type of Value (operand 0) must match the selected output type (or
//    its component type when ElementIndex is present).
// Returns SPV_SUCCESS or emits a diagnostic and returns
// SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraphSetOutput(ValidationState_t& _,
                                    const Instruction* inst) {
  // Check type of OutputIndex
  const uint32_t output_index_id = inst->GetOperandAs<uint32_t>(1);
  const auto output_index_inst = _.FindDef(output_index_id);
  if (!output_index_inst ||
      !_.IsIntScalarType(output_index_inst->type_id(), 32)) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << spvOpcodeString(inst->opcode())
           << " OutputIndex must be a 32-bit integer.";
  }

  const bool has_element_index = inst->operands().size() > 2;

  // Check type of ElementIndex
  if (has_element_index) {
    const auto element_index_inst =
        _.FindDef(inst->GetOperandAs<uint32_t>(2));
    if (!element_index_inst ||
        !_.IsIntScalarType(element_index_inst->type_id(), 32)) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << spvOpcodeString(inst->opcode())
             << " ElementIndex must be a 32-bit integer.";
    }
  }

  // Find graph definition: the closest preceding OpGraphARM. LineNum() - 1 is
  // used throughout this file as the position of |inst| in
  // ordered_instructions(). The loop cannot wrap around when that position is
  // 0, and |graph_inst| stays null when no graph is found.
  const auto& instructions = _.ordered_instructions();
  const Instruction* graph_inst = nullptr;
  const size_t line_num = inst->LineNum();
  for (size_t index = (line_num == 0 ? 0 : line_num - 1); index-- > 0;) {
    if (instructions[index].opcode() == spv::Op::OpGraphARM) {
      graph_inst = &instructions[index];
      break;
    }
  }

  // Can the OutputIndex be evaluated?
  // If not, there's nothing more we can validate here.
  uint64_t output_index = 0;
  if (!_.EvalConstantValUint64(output_index_id, &output_index)) {
    return SPV_SUCCESS;
  }

  // Without an enclosing graph whose type is an OpTypeGraphARM there is
  // nothing to compare against. Stop here rather than read from an unrelated
  // instruction. NOTE(review): such modules are presumably rejected by
  // ValidateGraph or by layout validation - confirm.
  if (graph_inst == nullptr) {
    return SPV_SUCCESS;
  }
  const auto graph_type_inst = _.FindDef(graph_inst->type_id());
  if (!graph_type_inst ||
      graph_type_inst->opcode() != spv::Op::OpTypeGraphARM) {
    return SPV_SUCCESS;
  }

  // Check that the OutputIndex is valid with respect to the graph type
  const size_t graph_type_num_outputs =
      GraphTypeInstNumOutputs(graph_type_inst);

  if (output_index >= graph_type_num_outputs) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << spvOpcodeString(inst->opcode()) << " setting OutputIndex "
           << output_index << " but graph only has " << graph_type_num_outputs
           << " outputs.";
  }

  const uint32_t graph_type_output_type =
      GraphTypeInstGetOutputAtIndex(graph_type_inst, output_index);

  if (has_element_index) {
    // Check ElementIndex is allowed
    if (!IsTensorArray(_, graph_type_output_type)) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "OpGraphSetOutputARM ElementIndex not allowed when the graph "
                "output selected by "
             << "OutputIndex is not an OpTypeArray or OpTypeRuntimeArray";
    }

    // Check ElementIndex is in range if it can be evaluated and the output is a
    // fixed-sized array whose Length can be evaluated
    uint64_t element_index = 0;
    if (_.IsArrayType(graph_type_output_type) &&
        _.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(2),
                                &element_index)) {
      uint64_t array_length = 0;
      const auto graph_type_output_type_inst =
          _.FindDef(graph_type_output_type);
      if (_.EvalConstantValUint64(
              graph_type_output_type_inst->GetOperandAs<uint32_t>(2),
              &array_length)) {
        if (element_index >= array_length) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << "OpGraphSetOutputARM ElementIndex out of range. The type "
                    "of the graph output being accessed "
                 << _.getIdName(graph_type_output_type) << " is an array of "
                 << array_length << " elements but " << "ElementIndex is "
                 << element_index;
        }
      }
    }
  }

  // Check Value's type matches with graph type. An undefined Value is given
  // type ID 0, which never matches and so produces the mismatch diagnostic
  // instead of a null dereference.
  const uint32_t value = inst->GetOperandAs<uint32_t>(0);
  const auto value_inst = _.FindDef(value);
  const uint32_t value_type = value_inst ? value_inst->type_id() : 0u;
  if (has_element_index) {
    const uint32_t expected_type = _.GetComponentType(graph_type_output_type);
    if (value_type != expected_type) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "The type " << _.getIdName(value_type)
             << " of Value provided to the graph output instruction "
             << _.getIdName(value) << " does not match the component type "
             << _.getIdName(expected_type) << " of output " << output_index
             << " in the graph type.";
    }
  } else {
    if (value_type != graph_type_output_type) {
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
             << "The type " << _.getIdName(value_type)
             << " of Value provided to the graph output instruction "
             << _.getIdName(value) << " does not match the type "
             << _.getIdName(graph_type_output_type) << " of output "
             << output_index << " in the graph type.";
    }
  }
  return SPV_SUCCESS;
}
|
||||
|
||||
// Reports whether |inout_insts| (all OpGraphInputARM or all
// OpGraphSetOutputARM instructions of one graph) contains two instructions
// that address the same input/output.
//
// Two instructions conflict when they use the same Input/OutputIndex, unless
// both carry an ElementIndex and the two ElementIndex values differ. The check
// does not depend on the order of the instructions. Instructions whose index
// (or ElementIndex, when present) cannot be evaluated as a constant are
// ignored.
//
// On a conflict, |*first_dup| is set to the later of the two conflicting
// instructions and true is returned; otherwise |*first_dup| is left untouched
// and false is returned.
bool InputOutputInstructionsHaveDuplicateIndices(
    ValidationState_t& _, std::deque<const Instruction*>& inout_insts,
    const Instruction** first_dup) {
  using IndexPair = std::pair<uint64_t, uint64_t>;
  // Sentinel stored as the second member when no ElementIndex is present.
  // ElementIndex operands are validated elsewhere as 32-bit integers, so an
  // evaluated value is not expected to reach it - TODO confirm.
  const uint64_t kNoElementIndex = ~uint64_t{0};

  std::set<IndexPair> seen;
  for (const Instruction* inst : inout_insts) {
    const bool is_input = inst->opcode() == spv::Op::OpGraphInputARM;
    const uint32_t index_operand = is_input ? 2u : 1u;
    const uint32_t element_operand = index_operand + 1u;
    const bool has_element_index = inst->operands().size() > element_operand;

    uint64_t inout_index = 0;
    if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(index_operand),
                                 &inout_index)) {
      continue;
    }

    uint64_t element_index = kNoElementIndex;
    if (has_element_index &&
        !_.EvalConstantValUint64(
            inst->GetOperandAs<uint32_t>(element_operand), &element_index)) {
      continue;
    }

    bool duplicate = false;
    if (has_element_index) {
      // Conflicts with the same (index, element) pair or with an earlier
      // instruction that addressed the whole input/output.
      duplicate = seen.count(IndexPair(inout_index, element_index)) != 0 ||
                  seen.count(IndexPair(inout_index, kNoElementIndex)) != 0;
    } else {
      // Addresses the whole input/output: conflicts with ANY earlier
      // instruction using this index, with or without an ElementIndex. The
      // set is ordered by (index, element), so the first entry not less than
      // (index, 0) tells whether the index was seen at all.
      const auto it = seen.lower_bound(IndexPair(inout_index, 0));
      duplicate = it != seen.end() && it->first == inout_index;
    }

    if (duplicate) {
      *first_dup = inst;
      return true;
    }
    seen.insert(IndexPair(inout_index, element_index));
  }
  return false;
}
|
||||
|
||||
// Validates an OpGraphEndARM instruction by checking the graph definition it
// closes: no two OpGraphInputARM instructions and no two OpGraphSetOutputARM
// instructions may address the same input/output (see
// InputOutputInstructionsHaveDuplicateIndices for the exact rule).
// Returns SPV_SUCCESS or emits a diagnostic on the offending instruction and
// returns SPV_ERROR_INVALID_DATA.
spv_result_t ValidateGraphEnd(ValidationState_t& _, const Instruction* inst) {
  // Gather OpGraphInputARM and OpGraphSetOutputARM instructions by walking
  // backwards to the closest preceding OpGraphARM. push_front restores module
  // order. LineNum() - 1 is used throughout this file as the position of
  // |inst| in ordered_instructions(); the loop cannot wrap around when that
  // position is 0.
  std::deque<const Instruction*> graph_inputs, graph_outputs;
  const auto& instructions = _.ordered_instructions();
  const size_t line_num = inst->LineNum();
  for (size_t index = (line_num == 0 ? 0 : line_num - 1); index-- > 0;) {
    const Instruction* candidate = &instructions[index];
    const spv::Op candidate_op = candidate->opcode();
    if (candidate_op == spv::Op::OpGraphARM) {
      break;
    }
    if (candidate_op == spv::Op::OpGraphInputARM) {
      graph_inputs.push_front(candidate);
    } else if (candidate_op == spv::Op::OpGraphSetOutputARM) {
      graph_outputs.push_front(candidate);
    }
  }

  const Instruction* first_dup = nullptr;

  // Check that there are no duplicate InputIndex and ElementIndex values
  if (InputOutputInstructionsHaveDuplicateIndices(_, graph_inputs,
                                                  &first_dup)) {
    return _.diag(SPV_ERROR_INVALID_DATA, first_dup)
           << "Two OpGraphInputARM instructions with the same InputIndex "
              "must not be part of the same "
           << "graph definition unless ElementIndex is present in both with "
              "different values.";
  }

  // Check that there are no duplicate OutputIndex and ElementIndex values
  if (InputOutputInstructionsHaveDuplicateIndices(_, graph_outputs,
                                                  &first_dup)) {
    return _.diag(SPV_ERROR_INVALID_DATA, first_dup)
           << "Two OpGraphSetOutputARM instructions with the same "
              "OutputIndex must not be part of the same "
           << "graph definition unless ElementIndex is present in both with "
              "different values.";
  }

  return SPV_SUCCESS;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Validates correctness of graph instructions.
|
||||
spv_result_t GraphPass(ValidationState_t& _, const Instruction* inst) {
|
||||
switch (inst->opcode()) {
|
||||
case spv::Op::OpTypeGraphARM:
|
||||
return ValidateGraphType(_, inst);
|
||||
case spv::Op::OpGraphConstantARM:
|
||||
return ValidateGraphConstant(_, inst);
|
||||
case spv::Op::OpGraphEntryPointARM:
|
||||
return ValidateGraphEntryPoint(_, inst);
|
||||
case spv::Op::OpGraphARM:
|
||||
return ValidateGraph(_, inst);
|
||||
case spv::Op::OpGraphInputARM:
|
||||
return ValidateGraphInput(_, inst);
|
||||
case spv::Op::OpGraphSetOutputARM:
|
||||
return ValidateGraphSetOutput(_, inst);
|
||||
case spv::Op::OpGraphEndARM:
|
||||
return ValidateGraphEnd(_, inst);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace val
|
||||
} // namespace spvtools
|
||||
89
3rdparty/spirv-tools/source/val/validate_id.cpp
vendored
89
3rdparty/spirv-tools/source/val/validate_id.cpp
vendored
@@ -115,6 +115,57 @@ spv_result_t CheckIdDefinitionDominateUse(ValidationState_t& _) {
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
bool InstructionCanHaveTypeOperand(const Instruction* inst) {
|
||||
static std::unordered_set<spv::Op> instruction_allow_set{
|
||||
spv::Op::OpSizeOf,
|
||||
spv::Op::OpCooperativeMatrixLengthNV,
|
||||
spv::Op::OpCooperativeMatrixLengthKHR,
|
||||
spv::Op::OpUntypedArrayLengthKHR,
|
||||
spv::Op::OpFunction,
|
||||
spv::Op::OpAsmINTEL,
|
||||
};
|
||||
const auto opcode = inst->opcode();
|
||||
bool type_instruction = spvOpcodeGeneratesType(opcode);
|
||||
bool debug_instruction = spvOpcodeIsDebug(opcode) || inst->IsDebugInfo();
|
||||
bool coop_matrix_spec_constant_op_length =
|
||||
(opcode == spv::Op::OpSpecConstantOp) &&
|
||||
(spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthNV ||
|
||||
spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthKHR);
|
||||
return type_instruction || debug_instruction || inst->IsNonSemantic() ||
|
||||
spvOpcodeIsDecoration(opcode) || instruction_allow_set.count(opcode) ||
|
||||
spvOpcodeGeneratesUntypedPointer(opcode) ||
|
||||
coop_matrix_spec_constant_op_length;
|
||||
}
|
||||
|
||||
bool InstructionRequiresTypeOperand(const Instruction* inst) {
|
||||
static std::unordered_set<spv::Op> instruction_deny_set{
|
||||
spv::Op::OpExtInst,
|
||||
spv::Op::OpExtInstWithForwardRefsKHR,
|
||||
spv::Op::OpExtInstImport,
|
||||
spv::Op::OpSelectionMerge,
|
||||
spv::Op::OpLoopMerge,
|
||||
spv::Op::OpFunction,
|
||||
spv::Op::OpSizeOf,
|
||||
spv::Op::OpCooperativeMatrixLengthNV,
|
||||
spv::Op::OpCooperativeMatrixLengthKHR,
|
||||
spv::Op::OpPhi,
|
||||
spv::Op::OpUntypedArrayLengthKHR,
|
||||
spv::Op::OpAsmINTEL,
|
||||
};
|
||||
const auto opcode = inst->opcode();
|
||||
bool debug_instruction = spvOpcodeIsDebug(opcode) || inst->IsDebugInfo();
|
||||
bool coop_matrix_spec_constant_op_length =
|
||||
opcode == spv::Op::OpSpecConstantOp &&
|
||||
(spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthNV ||
|
||||
spv::Op(inst->word(3)) == spv::Op::OpCooperativeMatrixLengthKHR);
|
||||
|
||||
return !debug_instruction && !inst->IsNonSemantic() &&
|
||||
!spvOpcodeIsDecoration(opcode) && !spvOpcodeIsBranch(opcode) &&
|
||||
!instruction_deny_set.count(opcode) &&
|
||||
!spvOpcodeGeneratesUntypedPointer(opcode) &&
|
||||
!coop_matrix_spec_constant_op_length;
|
||||
}
|
||||
|
||||
// Performs SSA validation on the IDs of an instruction. The
|
||||
// can_have_forward_declared_ids functor should return true if the
|
||||
// instruction operand's ID can be forward referenced.
|
||||
@@ -158,44 +209,14 @@ spv_result_t IdPass(ValidationState_t& _, Instruction* inst) {
|
||||
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
|
||||
case SPV_OPERAND_TYPE_SCOPE_ID:
|
||||
if (const auto def = _.FindDef(operand_word)) {
|
||||
const auto opcode = inst->opcode();
|
||||
if (spvOpcodeGeneratesType(def->opcode()) &&
|
||||
!spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) &&
|
||||
!inst->IsDebugInfo() && !inst->IsNonSemantic() &&
|
||||
!spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
|
||||
opcode != spv::Op::OpSizeOf &&
|
||||
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
|
||||
opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
|
||||
!spvOpcodeGeneratesUntypedPointer(opcode) &&
|
||||
opcode != spv::Op::OpUntypedArrayLengthKHR &&
|
||||
!(opcode == spv::Op::OpSpecConstantOp &&
|
||||
(spv::Op(inst->word(3)) ==
|
||||
spv::Op::OpCooperativeMatrixLengthNV ||
|
||||
spv::Op(inst->word(3)) ==
|
||||
spv::Op::OpCooperativeMatrixLengthKHR))) {
|
||||
!InstructionCanHaveTypeOperand(inst)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Operand " << _.getIdName(operand_word)
|
||||
<< " cannot be a type";
|
||||
} else if (def->type_id() == 0 && !spvOpcodeGeneratesType(opcode) &&
|
||||
!spvOpcodeIsDebug(opcode) && !inst->IsDebugInfo() &&
|
||||
!inst->IsNonSemantic() && !spvOpcodeIsDecoration(opcode) &&
|
||||
!spvOpcodeIsBranch(opcode) && opcode != spv::Op::OpPhi &&
|
||||
opcode != spv::Op::OpExtInst &&
|
||||
opcode != spv::Op::OpExtInstWithForwardRefsKHR &&
|
||||
opcode != spv::Op::OpExtInstImport &&
|
||||
opcode != spv::Op::OpSelectionMerge &&
|
||||
opcode != spv::Op::OpLoopMerge &&
|
||||
opcode != spv::Op::OpFunction &&
|
||||
opcode != spv::Op::OpSizeOf &&
|
||||
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
|
||||
opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
|
||||
!spvOpcodeGeneratesUntypedPointer(opcode) &&
|
||||
opcode != spv::Op::OpUntypedArrayLengthKHR &&
|
||||
!(opcode == spv::Op::OpSpecConstantOp &&
|
||||
(spv::Op(inst->word(3)) ==
|
||||
spv::Op::OpCooperativeMatrixLengthNV ||
|
||||
spv::Op(inst->word(3)) ==
|
||||
spv::Op::OpCooperativeMatrixLengthKHR))) {
|
||||
} else if (def->type_id() == 0 &&
|
||||
!spvOpcodeGeneratesType(def->opcode()) &&
|
||||
InstructionRequiresTypeOperand(inst)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Operand " << _.getIdName(operand_word)
|
||||
<< " requires a type";
|
||||
|
||||
@@ -464,7 +464,9 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(10213)
|
||||
<< "Image Operand Offset can only be used with "
|
||||
"OpImage*Gather operations";
|
||||
"OpImage*Gather operations."
|
||||
<< _.MissingFeature("maintenance8 feature",
|
||||
"--allow-offset-texture-operand", false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +195,8 @@ spv_result_t CheckRequiredCapabilities(ValidationState_t& state,
|
||||
// registers a capability with the module *before* checking capabilities.
|
||||
// So in the case of an OpCapability instruction, don't bother checking
|
||||
// enablement by another capability.
|
||||
if (inst->opcode() != spv::Op::OpCapability) {
|
||||
if (inst->opcode() != spv::Op::OpCapability &&
|
||||
inst->opcode() != spv::Op::OpConditionalCapabilityINTEL) {
|
||||
const bool enabled_by_cap =
|
||||
state.HasAnyOfCapabilities(enabling_capabilities);
|
||||
if (!enabling_capabilities.empty() && !enabled_by_cap) {
|
||||
@@ -461,10 +462,13 @@ spv_result_t CheckIfKnownExtension(ValidationState_t& _,
|
||||
|
||||
spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
const spv::Op opcode = inst->opcode();
|
||||
if (opcode == spv::Op::OpExtension) {
|
||||
if (opcode == spv::Op::OpExtension ||
|
||||
opcode == spv::Op::OpConditionalExtensionINTEL) {
|
||||
CheckIfKnownExtension(_, inst);
|
||||
} else if (opcode == spv::Op::OpCapability) {
|
||||
_.RegisterCapability(inst->GetOperandAs<spv::Capability>(0));
|
||||
} else if (opcode == spv::Op::OpConditionalCapabilityINTEL) {
|
||||
_.RegisterCapability(inst->GetOperandAs<spv::Capability>(1));
|
||||
} else if (opcode == spv::Op::OpMemoryModel) {
|
||||
if (_.has_memory_model_specified()) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
|
||||
@@ -166,20 +166,17 @@ spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
|
||||
}
|
||||
break;
|
||||
case spv::Op::OpTypeMatrix:
|
||||
// Matrices consume locations equivalent to arrays.
|
||||
if (auto error = NumConsumedLocations(
|
||||
_, _.FindDef(type->GetOperandAs<uint32_t>(1)), num_locations)) {
|
||||
return error;
|
||||
}
|
||||
// Matrices consume locations equal to the underlying vector type for
|
||||
// each column.
|
||||
NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
|
||||
num_locations);
|
||||
*num_locations *= type->GetOperandAs<uint32_t>(2);
|
||||
break;
|
||||
case spv::Op::OpTypeArray: {
|
||||
// Arrays consume locations equal to the underlying type times the number
|
||||
// of elements in the vector.
|
||||
if (auto error = NumConsumedLocations(
|
||||
_, _.FindDef(type->GetOperandAs<uint32_t>(1)), num_locations)) {
|
||||
return error;
|
||||
}
|
||||
NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
|
||||
num_locations);
|
||||
bool is_int = false;
|
||||
bool is_const = false;
|
||||
uint32_t value = 0;
|
||||
@@ -249,31 +246,10 @@ uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
|
||||
NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
|
||||
num_components *= type->GetOperandAs<uint32_t>(2);
|
||||
break;
|
||||
case spv::Op::OpTypeMatrix:
|
||||
// Matrices consume all components of the location.
|
||||
// Round up to next multiple of 4.
|
||||
num_components =
|
||||
NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
|
||||
num_components *= type->GetOperandAs<uint32_t>(2);
|
||||
num_components = ((num_components + 3) / 4) * 4;
|
||||
break;
|
||||
case spv::Op::OpTypeArray: {
|
||||
// Arrays consume all components of the location.
|
||||
// Round up to next multiple of 4.
|
||||
num_components =
|
||||
NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
|
||||
|
||||
bool is_int = false;
|
||||
bool is_const = false;
|
||||
uint32_t value = 0;
|
||||
// Attempt to evaluate the number of array elements.
|
||||
std::tie(is_int, is_const, value) =
|
||||
_.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
|
||||
if (is_int && is_const) num_components *= value;
|
||||
|
||||
num_components = ((num_components + 3) / 4) * 4;
|
||||
return num_components;
|
||||
}
|
||||
case spv::Op::OpTypeArray:
|
||||
// Skip the array.
|
||||
return NumConsumedComponents(_,
|
||||
_.FindDef(type->GetOperandAs<uint32_t>(1)));
|
||||
case spv::Op::OpTypePointer:
|
||||
if (_.addressing_model() ==
|
||||
spv::AddressingModel::PhysicalStorageBuffer64 &&
|
||||
@@ -356,10 +332,9 @@ spv_result_t GetLocationsForVariable(
|
||||
}
|
||||
}
|
||||
|
||||
// Vulkan 15.1.3 (Interface Matching): Tessellation control and mesh
|
||||
// per-vertex outputs and tessellation control, evaluation and geometry
|
||||
// per-vertex inputs have a layer of arraying that is not included in
|
||||
// interface matching.
|
||||
// Vulkan 14.1.3: Tessellation control and mesh per-vertex outputs and
|
||||
// tessellation control, evaluation and geometry per-vertex inputs have a
|
||||
// layer of arraying that is not included in interface matching.
|
||||
bool is_arrayed = false;
|
||||
switch (entry_point->GetOperandAs<spv::ExecutionModel>(0)) {
|
||||
case spv::ExecutionModel::TessellationControl:
|
||||
@@ -413,33 +388,51 @@ spv_result_t GetLocationsForVariable(
|
||||
|
||||
const std::string storage_class = is_output ? "output" : "input";
|
||||
if (has_location) {
|
||||
auto sub_type = type;
|
||||
bool is_int = false;
|
||||
bool is_const = false;
|
||||
uint32_t array_size = 1;
|
||||
// If the variable is still arrayed, mark the locations/components per
|
||||
// index.
|
||||
if (type->opcode() == spv::Op::OpTypeArray) {
|
||||
// Determine the array size if possible and get the element type.
|
||||
std::tie(is_int, is_const, array_size) =
|
||||
_.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
|
||||
if (!is_int || !is_const) array_size = 1;
|
||||
auto sub_type_id = type->GetOperandAs<uint32_t>(1);
|
||||
sub_type = _.FindDef(sub_type_id);
|
||||
}
|
||||
|
||||
uint32_t num_locations = 0;
|
||||
if (auto error = NumConsumedLocations(_, type, &num_locations))
|
||||
if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
|
||||
return error;
|
||||
uint32_t num_components = NumConsumedComponents(_, type);
|
||||
uint32_t num_components = NumConsumedComponents(_, sub_type);
|
||||
|
||||
uint32_t start = location * 4;
|
||||
uint32_t end = (location + num_locations) * 4;
|
||||
if (num_components % 4 != 0) {
|
||||
start += component;
|
||||
end = start + num_components;
|
||||
}
|
||||
for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
|
||||
uint32_t array_location = location + (num_locations * array_idx);
|
||||
uint32_t start = array_location * 4;
|
||||
if (kMaxLocations <= start) {
|
||||
// Too many locations, give up.
|
||||
break;
|
||||
}
|
||||
|
||||
if (kMaxLocations <= start) {
|
||||
// Too many locations, give up.
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
uint32_t end = (array_location + num_locations) * 4;
|
||||
if (num_components != 0) {
|
||||
start += component;
|
||||
end = array_location * 4 + component + num_components;
|
||||
}
|
||||
|
||||
auto locs = locations;
|
||||
if (has_index && index == 1) locs = output_index1_locations;
|
||||
auto locs = locations;
|
||||
if (has_index && index == 1) locs = output_index1_locations;
|
||||
|
||||
for (uint32_t i = start; i < end; ++i) {
|
||||
if (!locs->insert(i).second) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
|
||||
<< (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
|
||||
<< "Entry-point has conflicting " << storage_class
|
||||
<< " location assignment at location " << i / 4 << ", component "
|
||||
<< i % 4;
|
||||
for (uint32_t i = start; i < end; ++i) {
|
||||
if (!locs->insert(i).second) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
|
||||
<< (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
|
||||
<< "Entry-point has conflicting " << storage_class
|
||||
<< " location assignment at location " << i / 4
|
||||
<< ", component " << i % 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -498,19 +491,38 @@ spv_result_t GetLocationsForVariable(
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t end = (location + num_locations) * 4;
|
||||
if (num_components % 4 != 0) {
|
||||
start += component;
|
||||
end = location * 4 + component + num_components;
|
||||
}
|
||||
|
||||
for (uint32_t l = start; l < end; ++l) {
|
||||
if (!locations->insert(l).second) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
|
||||
<< (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
|
||||
<< "Entry-point has conflicting " << storage_class
|
||||
<< " location assignment at location " << l / 4
|
||||
<< ", component " << l % 4;
|
||||
if (member->opcode() == spv::Op::OpTypeArray && num_components >= 1 &&
|
||||
num_components < 4) {
|
||||
// When an array has an element that takes less than a location in
|
||||
// size, calculate the used locations in a strided manner.
|
||||
for (uint32_t l = location; l < num_locations + location; ++l) {
|
||||
for (uint32_t c = component; c < component + num_components; ++c) {
|
||||
uint32_t check = 4 * l + c;
|
||||
if (!locations->insert(check).second) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
|
||||
<< (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
|
||||
<< "Entry-point has conflicting " << storage_class
|
||||
<< " location assignment at location " << l
|
||||
<< ", component " << c;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO: There is a hole here is the member is an array of 3- or
|
||||
// 4-element vectors of 64-bit types.
|
||||
uint32_t end = (location + num_locations) * 4;
|
||||
if (num_components != 0) {
|
||||
start += component;
|
||||
end = location * 4 + component + num_components;
|
||||
}
|
||||
for (uint32_t l = start; l < end; ++l) {
|
||||
if (!locations->insert(l).second) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
|
||||
<< (is_output ? _.VkErrorID(8722) : _.VkErrorID(8721))
|
||||
<< "Entry-point has conflicting " << storage_class
|
||||
<< " location assignment at location " << l / 4
|
||||
<< ", component " << l % 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,12 +69,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
|
||||
case spv::Op::OpGroupNonUniformFMul:
|
||||
case spv::Op::OpGroupNonUniformFMin: {
|
||||
const uint32_t result_type = inst->type_id();
|
||||
if (_.IsBfloat16ScalarType(result_type) ||
|
||||
_.IsBfloat16VectorType(result_type)) {
|
||||
if (_.IsBfloat16Type(result_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
|
||||
}
|
||||
if (_.IsFP8ScalarOrVectorType(result_type)) {
|
||||
if (_.IsFP8Type(result_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " doesn't support FP8 E4M3/E5M2 types.";
|
||||
@@ -103,12 +102,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
|
||||
case spv::Op::OpIsNormal:
|
||||
case spv::Op::OpSignBitSet: {
|
||||
const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
|
||||
if (_.IsBfloat16ScalarType(operand_type) ||
|
||||
_.IsBfloat16VectorType(operand_type)) {
|
||||
if (_.IsBfloat16Type(operand_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
|
||||
}
|
||||
if (_.IsFP8ScalarOrVectorType(operand_type)) {
|
||||
if (_.IsFP8Type(operand_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " doesn't support FP8 E4M3/E5M2 types.";
|
||||
@@ -118,12 +116,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
|
||||
|
||||
case spv::Op::OpGroupNonUniformAllEqual: {
|
||||
const auto value_type = _.GetOperandTypeId(inst, 3);
|
||||
if (_.IsBfloat16ScalarType(value_type) ||
|
||||
_.IsBfloat16VectorType(value_type)) {
|
||||
if (_.IsBfloat16Type(value_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
|
||||
}
|
||||
if (_.IsFP8ScalarOrVectorType(value_type)) {
|
||||
if (_.IsFP8Type(value_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " doesn't support FP8 E4M3/E5M2 types.";
|
||||
@@ -140,12 +137,12 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
|
||||
uint32_t res_component_type = 0;
|
||||
if (_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
|
||||
&res_col_type, &res_component_type)) {
|
||||
if (_.IsBfloat16ScalarType(res_component_type)) {
|
||||
if (_.IsBfloat16Type(res_component_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " doesn't support BFloat16 type.";
|
||||
}
|
||||
if (_.IsFP8ScalarOrVectorType(res_component_type)) {
|
||||
if (_.IsFP8Type(res_component_type)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " doesn't support FP8 E4M3/E5M2 types.";
|
||||
|
||||
@@ -342,13 +342,84 @@ spv_result_t FunctionScopedInstructions(ValidationState_t& _,
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " cannot appear in a function declaration";
|
||||
_.ProgressToNextLayoutSectionOrder();
|
||||
// All function sections have been processed. Recursively call
|
||||
// ModuleLayoutPass to process the next section of the module
|
||||
return ModuleLayoutPass(_, inst);
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t GraphScopedInstructions(ValidationState_t& _,
|
||||
const Instruction* inst, spv::Op opcode) {
|
||||
if (_.IsOpcodeInCurrentLayoutSection(opcode)) {
|
||||
switch (opcode) {
|
||||
case spv::Op::OpGraphARM: {
|
||||
if (_.graph_definition_region() > kGraphDefinitionOutside) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< "Cannot define a graph in a graph";
|
||||
}
|
||||
_.SetGraphDefinitionRegion(kGraphDefinitionBegin);
|
||||
} break;
|
||||
case spv::Op::OpGraphInputARM: {
|
||||
if ((_.graph_definition_region() != kGraphDefinitionBegin) &&
|
||||
(_.graph_definition_region() != kGraphDefinitionInputs)) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< "OpGraphInputARM"
|
||||
<< " must immediately follow an OpGraphARM or OpGraphInputARM "
|
||||
"instruction.";
|
||||
}
|
||||
_.SetGraphDefinitionRegion(kGraphDefinitionInputs);
|
||||
} break;
|
||||
case spv::Op::OpGraphSetOutputARM: {
|
||||
if ((_.graph_definition_region() != kGraphDefinitionBegin) &&
|
||||
(_.graph_definition_region() != kGraphDefinitionInputs) &&
|
||||
(_.graph_definition_region() != kGraphDefinitionBody) &&
|
||||
(_.graph_definition_region() != kGraphDefinitionOutputs)) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< "Op" << spvOpcodeString(opcode)
|
||||
<< " must immediately precede an OpGraphEndARM or "
|
||||
"OpGraphSetOutputARM instruction.";
|
||||
}
|
||||
_.SetGraphDefinitionRegion(kGraphDefinitionOutputs);
|
||||
} break;
|
||||
case spv::Op::OpGraphEndARM: {
|
||||
if (_.graph_definition_region() != kGraphDefinitionOutputs) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " must be preceded by at least one OpGraphSetOutputARM "
|
||||
"instruction";
|
||||
}
|
||||
_.SetGraphDefinitionRegion(kGraphDefinitionOutside);
|
||||
} break;
|
||||
case spv::Op::OpGraphEntryPointARM:
|
||||
if (_.graph_definition_region() != kGraphDefinitionOutside) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " cannot appear in the definition of a graph";
|
||||
}
|
||||
break;
|
||||
default:
|
||||
if (_.graph_definition_region() == kGraphDefinitionOutside) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< "Op" << spvOpcodeString(opcode)
|
||||
<< " must appear in a graph body";
|
||||
}
|
||||
if (_.graph_definition_region() == kGraphDefinitionOutputs) {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< " cannot appear after a graph output instruction";
|
||||
}
|
||||
_.SetGraphDefinitionRegion(kGraphDefinitionBody);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
|
||||
<< "Op" << spvOpcodeString(opcode)
|
||||
<< " cannot appear in the graph definitions section";
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// TODO(umar): Check linkage capabilities for function declarations
|
||||
@@ -379,6 +450,11 @@ spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst) {
|
||||
return error;
|
||||
}
|
||||
break;
|
||||
case kLayoutGraphDefinitions:
|
||||
if (auto error = GraphScopedInstructions(_, inst, opcode)) {
|
||||
return error;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
142
3rdparty/spirv-tools/source/val/validate_memory.cpp
vendored
142
3rdparty/spirv-tools/source/val/validate_memory.cpp
vendored
@@ -196,10 +196,10 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
|
||||
return false;
|
||||
}
|
||||
|
||||
std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
|
||||
ValidationState_t& _, const Instruction* inst) {
|
||||
spv::StorageClass dst_sc = spv::StorageClass::Max;
|
||||
spv::StorageClass src_sc = spv::StorageClass::Max;
|
||||
std::pair<Instruction*, Instruction*> GetPointerTypes(ValidationState_t& _,
|
||||
const Instruction* inst) {
|
||||
Instruction* dst_pointer_type = nullptr;
|
||||
Instruction* src_pointer_type = nullptr;
|
||||
switch (inst->opcode()) {
|
||||
case spv::Op::OpCooperativeMatrixLoadNV:
|
||||
case spv::Op::OpCooperativeMatrixLoadTensorNV:
|
||||
@@ -207,8 +207,7 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
|
||||
case spv::Op::OpCooperativeVectorLoadNV:
|
||||
case spv::Op::OpLoad: {
|
||||
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
|
||||
auto load_pointer_type = _.FindDef(load_pointer->type_id());
|
||||
dst_sc = load_pointer_type->GetOperandAs<spv::StorageClass>(1);
|
||||
dst_pointer_type = _.FindDef(load_pointer->type_id());
|
||||
break;
|
||||
}
|
||||
case spv::Op::OpCooperativeMatrixStoreNV:
|
||||
@@ -217,25 +216,23 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
|
||||
case spv::Op::OpCooperativeVectorStoreNV:
|
||||
case spv::Op::OpStore: {
|
||||
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
||||
auto store_pointer_type = _.FindDef(store_pointer->type_id());
|
||||
dst_sc = store_pointer_type->GetOperandAs<spv::StorageClass>(1);
|
||||
dst_pointer_type = _.FindDef(store_pointer->type_id());
|
||||
break;
|
||||
}
|
||||
// Spec: "Matching Storage Class is not required"
|
||||
case spv::Op::OpCopyMemory:
|
||||
case spv::Op::OpCopyMemorySized: {
|
||||
auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
||||
auto dst_type = _.FindDef(dst->type_id());
|
||||
dst_sc = dst_type->GetOperandAs<spv::StorageClass>(1);
|
||||
auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
|
||||
auto src_type = _.FindDef(src->type_id());
|
||||
src_sc = src_type->GetOperandAs<spv::StorageClass>(1);
|
||||
auto dst_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
|
||||
dst_pointer_type = _.FindDef(dst_pointer->type_id());
|
||||
auto src_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(1));
|
||||
src_pointer_type = _.FindDef(src_pointer->type_id());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return std::make_pair(dst_sc, src_sc);
|
||||
return std::make_pair(dst_pointer_type, src_pointer_type);
|
||||
}
|
||||
|
||||
// Returns the number of instruction words taken up by a memory access
|
||||
@@ -288,8 +285,17 @@ bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
|
||||
|
||||
spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
||||
uint32_t index) {
|
||||
spv::StorageClass dst_sc, src_sc;
|
||||
std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
|
||||
Instruction* dst_pointer_type = nullptr;
|
||||
Instruction* src_pointer_type = nullptr; // only used for OpCopyMemory
|
||||
std::tie(dst_pointer_type, src_pointer_type) = GetPointerTypes(_, inst);
|
||||
|
||||
const spv::StorageClass dst_sc =
|
||||
dst_pointer_type ? dst_pointer_type->GetOperandAs<spv::StorageClass>(1)
|
||||
: spv::StorageClass::Max;
|
||||
const spv::StorageClass src_sc =
|
||||
src_pointer_type ? src_pointer_type->GetOperandAs<spv::StorageClass>(1)
|
||||
: spv::StorageClass::Max;
|
||||
|
||||
if (inst->operands().size() <= index) {
|
||||
// Cases where lack of some operand is invalid
|
||||
if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
|
||||
@@ -390,6 +396,23 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
|
||||
<< "Memory accesses Aligned operand value " << aligned_value
|
||||
<< " is not a power of two.";
|
||||
}
|
||||
|
||||
uint32_t largest_scalar = 0;
|
||||
if (dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
|
||||
largest_scalar =
|
||||
_.GetLargestScalarType(dst_pointer_type->GetOperandAs<uint32_t>(2));
|
||||
}
|
||||
if (src_sc == spv::StorageClass::PhysicalStorageBuffer) {
|
||||
largest_scalar = std::max(
|
||||
largest_scalar,
|
||||
_.GetLargestScalarType(src_pointer_type->GetOperandAs<uint32_t>(2)));
|
||||
}
|
||||
if (aligned_value < largest_scalar) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< _.VkErrorID(6314) << "Memory accesses Aligned operand value "
|
||||
<< aligned_value << " is too small, the largest scalar type is "
|
||||
<< largest_scalar << " bytes.";
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
@@ -435,6 +458,7 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< _.VkErrorID(11167)
|
||||
<< "Vulkan requires that data type be specified";
|
||||
}
|
||||
}
|
||||
@@ -1555,6 +1579,60 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Base type must be a non-pointer type";
|
||||
}
|
||||
|
||||
const auto ContainsBlock = [&_](const Instruction* type_inst) {
|
||||
if (type_inst->opcode() == spv::Op::OpTypeStruct) {
|
||||
if (_.HasDecoration(type_inst->id(), spv::Decoration::Block) ||
|
||||
_.HasDecoration(type_inst->id(), spv::Decoration::BufferBlock)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Block (and BufferBlock) arrays cannot be reinterpreted via untyped access
|
||||
// chains.
|
||||
const bool base_type_block_array =
|
||||
base_type->opcode() == spv::Op::OpTypeArray &&
|
||||
_.ContainsType(base_type->id(), ContainsBlock,
|
||||
/* traverse_all_types = */ false);
|
||||
|
||||
const auto base_index = untyped_pointer ? 3 : 2;
|
||||
const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
|
||||
auto base = _.FindDef(base_id);
|
||||
// Strictly speaking this misses trivial access chains and function
|
||||
// parameter chasing, but that would be a significant complication in the
|
||||
// traversal.
|
||||
while (base->opcode() == spv::Op::OpCopyObject) {
|
||||
base = _.FindDef(base->GetOperandAs<uint32_t>(2));
|
||||
}
|
||||
const Instruction* base_data_type = nullptr;
|
||||
if (base->opcode() == spv::Op::OpVariable) {
|
||||
const auto ptr_type = _.FindDef(base->type_id());
|
||||
base_data_type = _.FindDef(ptr_type->GetOperandAs<uint32_t>(2));
|
||||
} else if (base->opcode() == spv::Op::OpUntypedVariableKHR) {
|
||||
if (base->operands().size() > 3) {
|
||||
base_data_type = _.FindDef(base->GetOperandAs<uint32_t>(3));
|
||||
}
|
||||
}
|
||||
|
||||
if (base_data_type) {
|
||||
const bool base_block_array =
|
||||
base_data_type->opcode() == spv::Op::OpTypeArray &&
|
||||
_.ContainsType(base_data_type->id(), ContainsBlock,
|
||||
/* traverse_all_types = */ false);
|
||||
|
||||
if (base_type_block_array != base_block_array) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "Both Base Type and Base must be Block or BufferBlock arrays "
|
||||
"or neither can be";
|
||||
} else if (base_type_block_array && base_block_array &&
|
||||
base_type->id() != base_data_type->id()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "If Base or Base Type is a Block or BufferBlock array, the "
|
||||
"other must also be the same array";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Base must be a pointer, pointing to the base of a composite object.
|
||||
@@ -1845,14 +1923,34 @@ spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
|
||||
|
||||
const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
|
||||
|
||||
const auto base_id = inst->GetOperandAs<uint32_t>(2);
|
||||
const auto base = _.FindDef(base_id);
|
||||
const auto base_type = untyped_pointer
|
||||
? _.FindDef(inst->GetOperandAs<uint32_t>(2))
|
||||
: _.FindDef(base->type_id());
|
||||
const auto base_idx = untyped_pointer ? 3 : 2;
|
||||
const auto base = _.FindDef(inst->GetOperandAs<uint32_t>(base_idx));
|
||||
const auto base_type = _.FindDef(base->type_id());
|
||||
const auto base_type_storage_class =
|
||||
base_type->GetOperandAs<spv::StorageClass>(1);
|
||||
|
||||
const auto element_idx = untyped_pointer ? 4 : 3;
|
||||
const auto element = _.FindDef(inst->GetOperandAs<uint32_t>(element_idx));
|
||||
const auto element_type = _.FindDef(element->type_id());
|
||||
if (!element_type || element_type->opcode() != spv::Op::OpTypeInt) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Element must be an integer";
|
||||
}
|
||||
uint64_t element_val = 0;
|
||||
if (_.EvalConstantValUint64(element->id(), &element_val)) {
|
||||
if (element_val != 0) {
|
||||
const auto interp_type =
|
||||
untyped_pointer ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
|
||||
: _.FindDef(base_type->GetOperandAs<uint32_t>(2));
|
||||
if (interp_type->opcode() == spv::Op::OpTypeStruct &&
|
||||
(_.HasDecoration(interp_type->id(), spv::Decoration::Block) ||
|
||||
_.HasDecoration(interp_type->id(), spv::Decoration::BufferBlock))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Element must be 0 if the interpretation type is a Block- or "
|
||||
"BufferBlock-decorated structure";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (_.HasCapability(spv::Capability::Shader) &&
|
||||
(base_type_storage_class == spv::StorageClass::Uniform ||
|
||||
base_type_storage_class == spv::StorageClass::StorageBuffer ||
|
||||
|
||||
@@ -32,6 +32,9 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
|
||||
uint32_t value = 0;
|
||||
std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id);
|
||||
|
||||
const bool is_vulkan = spvIsVulkanEnv(_.context()->target_env) ||
|
||||
_.memory_model() == spv::MemoryModel::VulkanKHR;
|
||||
|
||||
if (!is_int32) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
@@ -56,6 +59,21 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::UniformMemory) &&
|
||||
!_.HasCapability(spv::Capability::Shader)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics UniformMemory requires capability Shader";
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR) &&
|
||||
!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics OutputMemoryKHR requires capability "
|
||||
<< "VulkanMemoryModelKHR";
|
||||
}
|
||||
|
||||
const size_t num_memory_order_set_bits = spvtools::utils::CountSetBits(
|
||||
value & uint32_t(spv::MemorySemanticsMask::Acquire |
|
||||
spv::MemorySemanticsMask::Release |
|
||||
@@ -64,197 +82,207 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
|
||||
|
||||
if (num_memory_order_set_bits > 1) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics can have at most one of the following "
|
||||
"bits set: Acquire, Release, AcquireRelease or "
|
||||
"SequentiallyConsistent";
|
||||
<< _.VkErrorID(10865) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics must have at most one non-relaxed "
|
||||
"memory order bit set";
|
||||
}
|
||||
|
||||
if (_.memory_model() == spv::MemoryModel::VulkanKHR &&
|
||||
value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent)) {
|
||||
if (is_vulkan &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "SequentiallyConsistent memory "
|
||||
"semantics cannot be used with "
|
||||
"the VulkanKHR memory model.";
|
||||
<< _.VkErrorID(10866) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics with SequentiallyConsistent memory order "
|
||||
"must not be used in the Vulkan API";
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR) &&
|
||||
!_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
|
||||
if ((opcode == spv::Op::OpAtomicStore ||
|
||||
opcode == spv::Op::OpAtomicFlagClear) &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Acquire) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics MakeAvailableKHR requires capability "
|
||||
<< "VulkanMemoryModelKHR";
|
||||
<< _.VkErrorID(10867) << spvOpcodeString(opcode)
|
||||
<< ": MemorySemantics must not use Acquire or AcquireRelease "
|
||||
"memory order with "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR) &&
|
||||
!_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
|
||||
if (opcode == spv::Op::OpAtomicLoad &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Release) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics MakeVisibleKHR requires capability "
|
||||
<< "VulkanMemoryModelKHR";
|
||||
<< _.VkErrorID(10868) << spvOpcodeString(opcode)
|
||||
<< ": MemorySemantics must not use Release or AcquireRelease "
|
||||
"memory order with "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR) &&
|
||||
!_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
|
||||
// In OpenCL, a relaxed fence has no effect but is not explicitly forbidden
|
||||
if (is_vulkan && opcode == spv::Op::OpMemoryBarrier &&
|
||||
!num_memory_order_set_bits) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics OutputMemoryKHR requires capability "
|
||||
<< "VulkanMemoryModelKHR";
|
||||
<< _.VkErrorID(10869) << spvOpcodeString(opcode)
|
||||
<< ": MemorySemantics must not use Relaxed memory order with "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (is_vulkan) {
|
||||
const bool includes_storage_class =
|
||||
value & uint32_t(spv::MemorySemanticsMask::UniformMemory |
|
||||
spv::MemorySemanticsMask::WorkgroupMemory |
|
||||
spv::MemorySemanticsMask::ImageMemory |
|
||||
spv::MemorySemanticsMask::OutputMemoryKHR);
|
||||
|
||||
if (num_memory_order_set_bits && !includes_storage_class) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(10870) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics with a non-relaxed memory order (Acquire, "
|
||||
"Release, or AcquireRelease) must have at least one "
|
||||
"Vulkan-supported storage class semantics bit set "
|
||||
"(UniformMemory, WorkgroupMemory, ImageMemory, or "
|
||||
"OutputMemory)";
|
||||
}
|
||||
|
||||
if (!num_memory_order_set_bits && includes_storage_class) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(10871) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics with at least one Vulkan-supported "
|
||||
"storage class semantics bit set (UniformMemory, "
|
||||
"WorkgroupMemory, ImageMemory, or OutputMemory) must use "
|
||||
"a non-relaxed memory order (Acquire, Release, or "
|
||||
"AcquireRelease)";
|
||||
}
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR)) {
|
||||
if (!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics MakeAvailableKHR requires capability "
|
||||
<< "VulkanMemoryModelKHR";
|
||||
}
|
||||
if (!(value & uint32_t(spv::MemorySemanticsMask::Release |
|
||||
spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(10872) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics with MakeAvailable bit set must use "
|
||||
"Release or AcquireRelease memory order";
|
||||
}
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR)) {
|
||||
if (!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics MakeVisibleKHR requires capability "
|
||||
<< "VulkanMemoryModelKHR";
|
||||
}
|
||||
if (!(value & uint32_t(spv::MemorySemanticsMask::Acquire |
|
||||
spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(10873) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics with MakeVisible bit set must use Acquire "
|
||||
"or AcquireRelease memory order";
|
||||
}
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::Volatile)) {
|
||||
if (!_.HasCapability(spv::Capability::VulkanMemoryModelKHR)) {
|
||||
if (!_.HasCapability(spv::Capability::VulkanMemoryModel)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics Volatile requires capability "
|
||||
"VulkanMemoryModelKHR";
|
||||
}
|
||||
|
||||
if (!spvOpcodeIsAtomicOp(inst->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Memory Semantics Volatile can only be used with atomic "
|
||||
"instructions";
|
||||
<< _.VkErrorID(10874) << spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics with Volatile bit set must not be used "
|
||||
"with barrier instructions";
|
||||
}
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::UniformMemory) &&
|
||||
!_.HasCapability(spv::Capability::Shader)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics UniformMemory requires capability Shader";
|
||||
}
|
||||
|
||||
// Checking for spv::Capability::AtomicStorage is intentionally not done here.
|
||||
// See https://github.com/KhronosGroup/glslang/issues/1618 for the reasoning
|
||||
// why.
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR |
|
||||
spv::MemorySemanticsMask::MakeVisibleKHR)) {
|
||||
const bool includes_storage_class =
|
||||
value & uint32_t(spv::MemorySemanticsMask::UniformMemory |
|
||||
spv::MemorySemanticsMask::SubgroupMemory |
|
||||
spv::MemorySemanticsMask::WorkgroupMemory |
|
||||
spv::MemorySemanticsMask::CrossWorkgroupMemory |
|
||||
spv::MemorySemanticsMask::AtomicCounterMemory |
|
||||
spv::MemorySemanticsMask::ImageMemory |
|
||||
spv::MemorySemanticsMask::OutputMemoryKHR);
|
||||
|
||||
if (!includes_storage_class) {
|
||||
if ((opcode == spv::Op::OpAtomicCompareExchange ||
|
||||
opcode == spv::Op::OpAtomicCompareExchangeWeak) &&
|
||||
operand_index == 5) {
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::Release) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": expected Memory Semantics to include a storage class";
|
||||
<< _.VkErrorID(10875) << spvOpcodeString(opcode)
|
||||
<< " Unequal Memory Semantics must not use Release or "
|
||||
"AcquireRelease memory order";
|
||||
}
|
||||
}
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR) &&
|
||||
!(value & uint32_t(spv::MemorySemanticsMask::Acquire |
|
||||
spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": MakeVisibleKHR Memory Semantics also requires either Acquire "
|
||||
"or AcquireRelease Memory Semantics";
|
||||
}
|
||||
bool is_equal_int32 = false;
|
||||
bool is_equal_const = false;
|
||||
uint32_t equal_value = 0;
|
||||
std::tie(is_equal_int32, is_equal_const, equal_value) =
|
||||
_.EvalInt32IfConst(inst->GetOperandAs<uint32_t>(4));
|
||||
|
||||
if (value & uint32_t(spv::MemorySemanticsMask::MakeAvailableKHR) &&
|
||||
!(value & uint32_t(spv::MemorySemanticsMask::Release |
|
||||
spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": MakeAvailableKHR Memory Semantics also requires either "
|
||||
"Release or AcquireRelease Memory Semantics";
|
||||
}
|
||||
const auto equal_mask_seq_cst =
|
||||
uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent);
|
||||
const auto equal_mask_acquire = uint32_t(
|
||||
// Allow EqualMemorySemantics Release with UnequalMemorySemantics
|
||||
// Acquire, since the C standard doesn't clearly forbid it.
|
||||
spv::MemorySemanticsMask::SequentiallyConsistent |
|
||||
spv::MemorySemanticsMask::AcquireRelease |
|
||||
spv::MemorySemanticsMask::Release | spv::MemorySemanticsMask::Acquire);
|
||||
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
const bool includes_storage_class =
|
||||
value & uint32_t(spv::MemorySemanticsMask::UniformMemory |
|
||||
spv::MemorySemanticsMask::WorkgroupMemory |
|
||||
spv::MemorySemanticsMask::ImageMemory |
|
||||
spv::MemorySemanticsMask::OutputMemoryKHR);
|
||||
|
||||
if (opcode == spv::Op::OpMemoryBarrier && !num_memory_order_set_bits) {
|
||||
if (((value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent)) &&
|
||||
!(equal_value & equal_mask_seq_cst)) ||
|
||||
((value & uint32_t(spv::MemorySemanticsMask::Acquire)) &&
|
||||
!(equal_value & equal_mask_acquire))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4732) << spvOpcodeString(opcode)
|
||||
<< ": Vulkan specification requires Memory Semantics to have "
|
||||
"one of the following bits set: Acquire, Release, "
|
||||
"AcquireRelease or SequentiallyConsistent";
|
||||
} else if (opcode != spv::Op::OpMemoryBarrier &&
|
||||
num_memory_order_set_bits) {
|
||||
// should leave only atomics and control barriers for Vulkan env
|
||||
bool memory_is_int32 = false, memory_is_const_int32 = false;
|
||||
uint32_t memory_value = 0;
|
||||
std::tie(memory_is_int32, memory_is_const_int32, memory_value) =
|
||||
_.EvalInt32IfConst(memory_scope);
|
||||
if (memory_is_int32 &&
|
||||
spv::Scope(memory_value) == spv::Scope::Invocation) {
|
||||
<< _.VkErrorID(10876) << spvOpcodeString(opcode)
|
||||
<< " Unequal Memory Semantics must not use a stronger memory "
|
||||
"order than the corresponding Equal Memory Semantics";
|
||||
}
|
||||
|
||||
if (is_vulkan) {
|
||||
auto storage_class_semantics_mask =
|
||||
uint32_t(spv::MemorySemanticsMask::UniformMemory |
|
||||
spv::MemorySemanticsMask::WorkgroupMemory |
|
||||
spv::MemorySemanticsMask::ImageMemory |
|
||||
spv::MemorySemanticsMask::OutputMemoryKHR);
|
||||
|
||||
if (value & ~equal_value & storage_class_semantics_mask) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4641) << spvOpcodeString(opcode)
|
||||
<< ": Vulkan specification requires Memory Semantics to be None "
|
||||
"if used with Invocation Memory Scope";
|
||||
<< _.VkErrorID(10877) << spvOpcodeString(opcode)
|
||||
<< " Unequal Memory Semantics must not have any "
|
||||
"Vulkan-supported storage class semantics bit set "
|
||||
"(UniformMemory, WorkgroupMemory, ImageMemory, or "
|
||||
"OutputMemory) unless this bit is also set in the "
|
||||
"corresponding Equal Memory Semantics";
|
||||
}
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpMemoryBarrier && !includes_storage_class) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4733) << spvOpcodeString(opcode)
|
||||
<< ": expected Memory Semantics to include a Vulkan-supported "
|
||||
"storage class";
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpControlBarrier && value) {
|
||||
if (!num_memory_order_set_bits) {
|
||||
if (value & ~equal_value &
|
||||
uint32_t(spv::MemorySemanticsMask::MakeVisibleKHR)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(10609) << spvOpcodeString(opcode)
|
||||
<< ": Vulkan specification requires non-zero Memory Semantics "
|
||||
"to have one of the following bits set: Acquire, Release, "
|
||||
"AcquireRelease or SequentiallyConsistent";
|
||||
<< _.VkErrorID(10878) << spvOpcodeString(opcode)
|
||||
<< " Unequal Memory Semantics must not have MakeVisible bit set "
|
||||
"unless this bit is also set in the corresponding Equal "
|
||||
"Memory Semantics";
|
||||
}
|
||||
if (!includes_storage_class) {
|
||||
|
||||
if ((equal_value & uint32_t(spv::MemorySemanticsMask::Volatile)) ^
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Volatile))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4650) << spvOpcodeString(opcode)
|
||||
<< ": expected Memory Semantics to include a Vulkan-supported "
|
||||
"storage class if Memory Semantics is not None";
|
||||
<< _.VkErrorID(10879) << spvOpcodeString(opcode)
|
||||
<< " Unequal Memory Semantics must have Volatile bit set if and "
|
||||
"only if this bit is also set in the corresponding Equal "
|
||||
"Memory Semantics";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpAtomicFlagClear &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Acquire) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Memory Semantics Acquire and AcquireRelease cannot be used "
|
||||
"with "
|
||||
<< spvOpcodeString(opcode);
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpAtomicCompareExchange && operand_index == 5 &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Release) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< spvOpcodeString(opcode)
|
||||
<< ": Memory Semantics Release and AcquireRelease cannot be "
|
||||
"used "
|
||||
"for operand Unequal";
|
||||
}
|
||||
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
if (opcode == spv::Op::OpAtomicLoad &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Release) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent))) {
|
||||
if (is_vulkan && num_memory_order_set_bits) {
|
||||
bool memory_is_int32 = false, memory_is_const_int32 = false;
|
||||
uint32_t memory_value = 0;
|
||||
std::tie(memory_is_int32, memory_is_const_int32, memory_value) =
|
||||
_.EvalInt32IfConst(memory_scope);
|
||||
if (memory_is_int32 && spv::Scope(memory_value) == spv::Scope::Invocation) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4731)
|
||||
<< "Vulkan spec disallows OpAtomicLoad with Memory Semantics "
|
||||
"Release, AcquireRelease and SequentiallyConsistent";
|
||||
}
|
||||
|
||||
if (opcode == spv::Op::OpAtomicStore &&
|
||||
(value & uint32_t(spv::MemorySemanticsMask::Acquire) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::AcquireRelease) ||
|
||||
value & uint32_t(spv::MemorySemanticsMask::SequentiallyConsistent))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4730)
|
||||
<< "Vulkan spec disallows OpAtomicStore with Memory Semantics "
|
||||
"Acquire, AcquireRelease and SequentiallyConsistent";
|
||||
<< _.VkErrorID(4641) << spvOpcodeString(opcode)
|
||||
<< ": Vulkan specification requires Memory Semantics to be "
|
||||
"Relaxed if used with Invocation Memory Scope";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,20 +59,22 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
|
||||
const auto* execution_modes = _.GetExecutionModes(entry_point_id);
|
||||
auto has_mode = [&execution_modes](spv::ExecutionMode mode) {
|
||||
return execution_modes && execution_modes->count(mode);
|
||||
};
|
||||
|
||||
if (_.HasCapability(spv::Capability::Shader)) {
|
||||
switch (execution_model) {
|
||||
case spv::ExecutionModel::Fragment:
|
||||
if (execution_modes &&
|
||||
execution_modes->count(spv::ExecutionMode::OriginUpperLeft) &&
|
||||
execution_modes->count(spv::ExecutionMode::OriginLowerLeft)) {
|
||||
if (has_mode(spv::ExecutionMode::OriginUpperLeft) &&
|
||||
has_mode(spv::ExecutionMode::OriginLowerLeft)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Fragment execution model entry points can only specify "
|
||||
"one of OriginUpperLeft or OriginLowerLeft execution "
|
||||
"modes.";
|
||||
}
|
||||
if (!execution_modes ||
|
||||
(!execution_modes->count(spv::ExecutionMode::OriginUpperLeft) &&
|
||||
!execution_modes->count(spv::ExecutionMode::OriginLowerLeft))) {
|
||||
if (!has_mode(spv::ExecutionMode::OriginUpperLeft) &&
|
||||
!has_mode(spv::ExecutionMode::OriginLowerLeft)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Fragment execution model entry points require either an "
|
||||
"OriginUpperLeft or OriginLowerLeft execution mode.";
|
||||
@@ -285,36 +287,31 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
}
|
||||
|
||||
bool has_workgroup_size = false;
|
||||
bool has_local_size_id = false;
|
||||
for (auto& i : _.ordered_instructions()) {
|
||||
if (i.opcode() == spv::Op::OpFunction) break;
|
||||
if (i.opcode() == spv::Op::OpDecorate && i.operands().size() > 2) {
|
||||
if (i.GetOperandAs<spv::Decoration>(1) == spv::Decoration::BuiltIn &&
|
||||
i.GetOperandAs<spv::BuiltIn>(2) == spv::BuiltIn::WorkgroupSize) {
|
||||
has_workgroup_size = true;
|
||||
}
|
||||
}
|
||||
if (i.opcode() == spv::Op::OpExecutionModeId) {
|
||||
if (i.GetOperandAs<spv::ExecutionMode>(1) ==
|
||||
spv::ExecutionMode::LocalSizeId) {
|
||||
has_local_size_id = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
switch (execution_model) {
|
||||
case spv::ExecutionModel::GLCompute:
|
||||
if (!execution_modes ||
|
||||
!execution_modes->count(spv::ExecutionMode::LocalSize)) {
|
||||
bool ok = false;
|
||||
for (auto& i : _.ordered_instructions()) {
|
||||
if (i.opcode() == spv::Op::OpDecorate) {
|
||||
if (i.operands().size() > 2) {
|
||||
if (i.GetOperandAs<spv::Decoration>(1) ==
|
||||
spv::Decoration::BuiltIn &&
|
||||
i.GetOperandAs<spv::BuiltIn>(2) ==
|
||||
spv::BuiltIn::WorkgroupSize) {
|
||||
ok = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (i.opcode() == spv::Op::OpExecutionModeId) {
|
||||
const auto mode = i.GetOperandAs<spv::ExecutionMode>(1);
|
||||
if (mode == spv::ExecutionMode::LocalSizeId) {
|
||||
ok = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!has_mode(spv::ExecutionMode::LocalSize)) {
|
||||
bool ok = has_workgroup_size || has_local_size_id;
|
||||
if (!ok && _.HasCapability(spv::Capability::TileShadingQCOM)) {
|
||||
ok =
|
||||
execution_modes &&
|
||||
execution_modes->count(spv::ExecutionMode::TileShadingRateQCOM);
|
||||
ok = has_mode(spv::ExecutionMode::TileShadingRateQCOM);
|
||||
}
|
||||
if (!ok) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@@ -332,25 +329,20 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
|
||||
if (_.HasCapability(spv::Capability::TileShadingQCOM)) {
|
||||
if (execution_modes) {
|
||||
if (execution_modes->count(
|
||||
spv::ExecutionMode::TileShadingRateQCOM) &&
|
||||
(execution_modes->count(spv::ExecutionMode::LocalSize) ||
|
||||
execution_modes->count(spv::ExecutionMode::LocalSizeId))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "If the TileShadingRateQCOM execution mode is used, "
|
||||
<< "LocalSize and LocalSizeId must not be specified.";
|
||||
}
|
||||
if (execution_modes->count(
|
||||
spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "The NonCoherentTileAttachmentQCOM execution mode must "
|
||||
"not be used in any stage other than fragment.";
|
||||
}
|
||||
if (has_mode(spv::ExecutionMode::TileShadingRateQCOM) &&
|
||||
(has_mode(spv::ExecutionMode::LocalSize) ||
|
||||
has_mode(spv::ExecutionMode::LocalSizeId))) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "If the TileShadingRateQCOM execution mode is used, "
|
||||
<< "LocalSize and LocalSizeId must not be specified.";
|
||||
}
|
||||
if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "The NonCoherentTileAttachmentQCOM execution mode must "
|
||||
"not be used in any stage other than fragment.";
|
||||
}
|
||||
} else {
|
||||
if (execution_modes &&
|
||||
execution_modes->count(spv::ExecutionMode::TileShadingRateQCOM)) {
|
||||
if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "If the TileShadingRateQCOM execution mode is used, the "
|
||||
"TileShadingQCOM capability must be enabled.";
|
||||
@@ -358,16 +350,13 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
break;
|
||||
default:
|
||||
if (execution_modes &&
|
||||
execution_modes->count(spv::ExecutionMode::TileShadingRateQCOM)) {
|
||||
if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "The TileShadingRateQCOM execution mode must not be used "
|
||||
"in any stage other than compute.";
|
||||
}
|
||||
if (execution_model != spv::ExecutionModel::Fragment) {
|
||||
if (execution_modes &&
|
||||
execution_modes->count(
|
||||
spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
|
||||
if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "The NonCoherentTileAttachmentQCOM execution mode must "
|
||||
"not be used in any stage other than fragment.";
|
||||
@@ -378,9 +367,7 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
"any stage other than compute or fragment.";
|
||||
}
|
||||
} else {
|
||||
if (execution_modes &&
|
||||
execution_modes->count(
|
||||
spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
|
||||
if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) {
|
||||
if (!_.HasCapability(spv::Capability::TileShadingQCOM)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "If the NonCoherentTileAttachmentReadQCOM execution "
|
||||
@@ -393,7 +380,9 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
}
|
||||
}
|
||||
|
||||
if (_.EntryPointHasLocalSizeOrId(entry_point_id)) {
|
||||
// WorkgroupSize decoration takes precedence over any LocalSize or LocalSizeId
|
||||
// execution mode, so the values can be ignored
|
||||
if (_.EntryPointHasLocalSizeOrId(entry_point_id) && !has_workgroup_size) {
|
||||
const Instruction* local_size_inst =
|
||||
_.EntryPointLocalSizeOrId(entry_point_id);
|
||||
if (local_size_inst) {
|
||||
@@ -402,7 +391,8 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
const uint32_t operand_y = local_size_inst->GetOperandAs<uint32_t>(3);
|
||||
const uint32_t operand_z = local_size_inst->GetOperandAs<uint32_t>(4);
|
||||
if (mode == spv::ExecutionMode::LocalSize) {
|
||||
if ((operand_x * operand_y * operand_z) == 0) {
|
||||
const uint64_t product_size = operand_x * operand_y * operand_z;
|
||||
if (product_size == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< "Local Size execution mode must not have a product of zero "
|
||||
"(X "
|
||||
@@ -410,6 +400,32 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
<< operand_x << ", Y = " << operand_y << ", Z = " << operand_z
|
||||
<< ").";
|
||||
}
|
||||
if (has_mode(spv::ExecutionMode::DerivativeGroupQuadsKHR)) {
|
||||
if (operand_x % 2 != 0 || operand_y % 2 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< _.VkErrorID(10151)
|
||||
<< "Local Size execution mode dimensions is "
|
||||
"(X = "
|
||||
<< operand_x << ", Y = " << operand_y
|
||||
<< ") but Entry Point id " << entry_point_id
|
||||
<< " also has an DerivativeGroupQuadsKHR execution mode, so "
|
||||
"both dimensions must be a multiple of 2";
|
||||
}
|
||||
}
|
||||
if (has_mode(spv::ExecutionMode::DerivativeGroupLinearKHR)) {
|
||||
if (product_size % 4 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< _.VkErrorID(10152)
|
||||
<< "Local Size execution mode dimensions is (X = "
|
||||
<< operand_x << ", Y = " << operand_y
|
||||
<< ", Z = " << operand_z << ") but Entry Point id "
|
||||
<< entry_point_id
|
||||
<< " also has an DerivativeGroupLinearKHR execution mode, "
|
||||
"so "
|
||||
"the product ("
|
||||
<< product_size << ") must be a multiple of 4";
|
||||
}
|
||||
}
|
||||
} else if (mode == spv::ExecutionMode::LocalSizeId) {
|
||||
// can only validate product if static and not spec constant
|
||||
// (This is done for us in EvalConstantValUint64)
|
||||
@@ -417,13 +433,42 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
|
||||
bool static_x = _.EvalConstantValUint64(operand_x, &x_size);
|
||||
bool static_y = _.EvalConstantValUint64(operand_y, &y_size);
|
||||
bool static_z = _.EvalConstantValUint64(operand_z, &z_size);
|
||||
if (static_x && static_y && static_z &&
|
||||
((x_size * y_size * z_size) == 0)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< "Local Size Id execution mode must not have a product of "
|
||||
"zero "
|
||||
"(X = "
|
||||
<< x_size << ", Y = " << y_size << ", Z = " << z_size << ").";
|
||||
if (static_x && static_y && static_z) {
|
||||
const uint64_t product_size = x_size * y_size * z_size;
|
||||
if (product_size == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< "LocalSizeId execution mode must not have a product of "
|
||||
"zero "
|
||||
"(X = "
|
||||
<< x_size << ", Y = " << y_size << ", Z = " << z_size
|
||||
<< ").";
|
||||
}
|
||||
if (has_mode(spv::ExecutionMode::DerivativeGroupQuadsKHR)) {
|
||||
if (x_size % 2 != 0 || y_size % 2 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< _.VkErrorID(10151)
|
||||
<< "LocalSizeId execution mode dimensions is "
|
||||
"(X = "
|
||||
<< x_size << ", Y = " << y_size << ") but Entry Point id "
|
||||
<< entry_point_id
|
||||
<< " also has an DerivativeGroupQuadsKHR execution mode, "
|
||||
"so "
|
||||
"both dimensions must be a multiple of 2";
|
||||
}
|
||||
}
|
||||
if (has_mode(spv::ExecutionMode::DerivativeGroupLinearKHR)) {
|
||||
if (product_size % 4 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
|
||||
<< _.VkErrorID(10152)
|
||||
<< "LocalSizeId execution mode dimensions is (X = "
|
||||
<< x_size << ", Y = " << y_size << ", Z = " << z_size
|
||||
<< ") but Entry Point id " << entry_point_id
|
||||
<< " also has an DerivativeGroupLinearKHR execution mode, "
|
||||
"so "
|
||||
"the product ("
|
||||
<< product_size << ") must be a multiple of 4";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -557,6 +602,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
|
||||
"Operands that are not id operands.";
|
||||
}
|
||||
|
||||
const bool is_vulkan_env = (spvIsVulkanEnv(_.context()->target_env));
|
||||
const auto* models = _.GetExecutionModels(entry_point_id);
|
||||
switch (mode) {
|
||||
case spv::ExecutionMode::Invocations:
|
||||
@@ -667,7 +713,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
|
||||
"tessellation execution model.";
|
||||
}
|
||||
}
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
if (is_vulkan_env) {
|
||||
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
|
||||
inst->GetOperandAs<uint32_t>(2) == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@@ -690,8 +736,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
|
||||
"execution "
|
||||
"model.";
|
||||
}
|
||||
if (mode == spv::ExecutionMode::OutputPrimitivesEXT &&
|
||||
spvIsVulkanEnv(_.context()->target_env)) {
|
||||
if (mode == spv::ExecutionMode::OutputPrimitivesEXT && is_vulkan_env) {
|
||||
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
|
||||
inst->GetOperandAs<uint32_t>(2) == 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
@@ -761,9 +806,15 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
|
||||
break;
|
||||
case spv::ExecutionMode::LocalSize:
|
||||
case spv::ExecutionMode::LocalSizeId:
|
||||
if (mode == spv::ExecutionMode::LocalSizeId && !_.IsLocalSizeIdAllowed())
|
||||
if (mode == spv::ExecutionMode::LocalSizeId &&
|
||||
!_.IsLocalSizeIdAllowed()) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "LocalSizeId mode is not allowed by the current environment.";
|
||||
<< "LocalSizeId mode is not allowed by the current environment."
|
||||
<< (is_vulkan_env
|
||||
? _.MissingFeature("maintenance4 feature",
|
||||
"--allow-localsizeid", false)
|
||||
: "");
|
||||
}
|
||||
|
||||
if (!std::all_of(
|
||||
models->begin(), models->end(),
|
||||
@@ -812,7 +863,7 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
|
||||
}
|
||||
}
|
||||
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
if (is_vulkan_env) {
|
||||
if (mode == spv::ExecutionMode::OriginLowerLeft) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< _.VkErrorID(4653)
|
||||
|
||||
@@ -130,7 +130,7 @@ spv_result_t ValidateGroupNonUniformBroadcastShuffle(ValidationState_t& _,
|
||||
if (!spvOpcodeIsConstant(id_op)) {
|
||||
std::string operand = GetOperandName(inst->opcode());
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Before SPIR-V 1.5, " << operand
|
||||
<< "In SPIR-V 1.4 or earlier, " << operand
|
||||
<< " must be a constant instruction";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ spv_result_t ValidateExecutionScope(ValidationState_t& _,
|
||||
|
||||
// Vulkan specific rules
|
||||
if (spvIsVulkanEnv(_.context()->target_env)) {
|
||||
// Vulkan 1.1 specific rules
|
||||
// Subgroups were not added until 1.1
|
||||
if (_.context()->target_env != SPV_ENV_VULKAN_1_0) {
|
||||
// Scope for Non Uniform Group Operations must be limited to Subgroup
|
||||
if ((spvOpcodeIsNonUniformGroupOperation(opcode) &&
|
||||
|
||||
@@ -83,8 +83,7 @@ spv_result_t ValidateTensorRead(ValidationState_t& _, const Instruction* inst) {
|
||||
auto op_coord = inst->word(4);
|
||||
auto inst_coord = _.FindDef(op_coord);
|
||||
auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
|
||||
if (tensor_rank == 0 ||
|
||||
!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
|
||||
if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected Coordinates to be an array whose Element Type is an "
|
||||
"integer type and whose Length is equal to the Rank of Tensor.";
|
||||
@@ -143,8 +142,7 @@ spv_result_t ValidateTensorWrite(ValidationState_t& _,
|
||||
auto op_coord = inst->word(2);
|
||||
auto inst_coord = _.FindDef(op_coord);
|
||||
auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
|
||||
if (tensor_rank == 0 ||
|
||||
!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
|
||||
if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected Coordinates to be an array whose Element Type is an "
|
||||
"integer type and whose Length is equal to the Rank of Tensor.";
|
||||
|
||||
@@ -140,7 +140,7 @@ spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "8-bit floating point type requires an encoding.";
|
||||
}
|
||||
const spvtools::OperandDesc* desc;
|
||||
const spvtools::OperandDesc* desc = nullptr;
|
||||
const std::set<spv::FPEncoding> known_encodings{
|
||||
spv::FPEncoding::Float8E4M3EXT, spv::FPEncoding::Float8E5M2EXT};
|
||||
spv_result_t status = spvtools::LookupOperand(SPV_OPERAND_TYPE_FPENCODING,
|
||||
@@ -433,10 +433,9 @@ spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
|
||||
<< "Structure <id> " << _.getIdName(member_type_id)
|
||||
<< " contains members with BuiltIn decoration. Therefore this "
|
||||
<< "structure may not be contained as a member of another "
|
||||
<< "structure "
|
||||
<< "type. Structure <id> " << _.getIdName(struct_id)
|
||||
<< " contains structure <id> " << _.getIdName(member_type_id)
|
||||
<< ".";
|
||||
<< "structure " << "type. Structure <id> "
|
||||
<< _.getIdName(struct_id) << " contains structure <id> "
|
||||
<< _.getIdName(member_type_id) << ".";
|
||||
}
|
||||
|
||||
if (spvIsVulkanEnv(_.context()->target_env) &&
|
||||
@@ -562,6 +561,9 @@ spv_result_t ValidateTypePointer(ValidationState_t& _,
|
||||
// a storage image.
|
||||
if (sampled == 2) _.RegisterPointerToStorageImage(inst->id());
|
||||
}
|
||||
if (type->opcode() == spv::Op::OpTypeTensorARM) {
|
||||
_.RegisterPointerToTensor(inst->id());
|
||||
}
|
||||
}
|
||||
|
||||
if (!_.IsValidStorageClass(storage_class)) {
|
||||
@@ -614,6 +616,7 @@ spv_result_t ValidateTypeFunction(ValidationState_t& _,
|
||||
for (auto& pair : inst->uses()) {
|
||||
const auto* use = pair.first;
|
||||
if (use->opcode() != spv::Op::OpFunction &&
|
||||
use->opcode() != spv::Op::OpAsmINTEL &&
|
||||
!spvOpcodeIsDebug(use->opcode()) && !use->IsNonSemantic() &&
|
||||
!spvOpcodeIsDecoration(use->opcode())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, use)
|
||||
|
||||
280
3rdparty/spirv-tools/source/val/validation_state.cpp
vendored
280
3rdparty/spirv-tools/source/val/validation_state.cpp
vendored
@@ -42,14 +42,17 @@ ModuleLayoutSection InstructionLayoutSection(
|
||||
|
||||
switch (op) {
|
||||
case spv::Op::OpCapability:
|
||||
case spv::Op::OpConditionalCapabilityINTEL:
|
||||
return kLayoutCapabilities;
|
||||
case spv::Op::OpExtension:
|
||||
case spv::Op::OpConditionalExtensionINTEL:
|
||||
return kLayoutExtensions;
|
||||
case spv::Op::OpExtInstImport:
|
||||
return kLayoutExtInstImport;
|
||||
case spv::Op::OpMemoryModel:
|
||||
return kLayoutMemoryModel;
|
||||
case spv::Op::OpEntryPoint:
|
||||
case spv::Op::OpConditionalEntryPointINTEL:
|
||||
return kLayoutEntryPoint;
|
||||
case spv::Op::OpExecutionMode:
|
||||
case spv::Op::OpExecutionModeId:
|
||||
@@ -85,6 +88,9 @@ ModuleLayoutSection InstructionLayoutSection(
|
||||
// spv::Op::OpExtInst is only allowed in types section for certain
|
||||
// extended instruction sets. This will be checked separately.
|
||||
if (current_section == kLayoutTypes) return kLayoutTypes;
|
||||
// SpvOpExtInst is allowed in graph definitions.
|
||||
if (current_section == kLayoutGraphDefinitions)
|
||||
return kLayoutGraphDefinitions;
|
||||
return kLayoutFunctionDefinitions;
|
||||
case spv::Op::OpLine:
|
||||
case spv::Op::OpNoLine:
|
||||
@@ -99,6 +105,16 @@ ModuleLayoutSection InstructionLayoutSection(
|
||||
return kLayoutFunctionDefinitions;
|
||||
case spv::Op::OpSamplerImageAddressingModeNV:
|
||||
return kLayoutSamplerImageAddressMode;
|
||||
case spv::Op::OpGraphEntryPointARM:
|
||||
case spv::Op::OpGraphARM:
|
||||
case spv::Op::OpGraphInputARM:
|
||||
case spv::Op::OpGraphSetOutputARM:
|
||||
case spv::Op::OpGraphEndARM:
|
||||
return kLayoutGraphDefinitions;
|
||||
case spv::Op::OpCompositeExtract:
|
||||
if (current_section == kLayoutGraphDefinitions)
|
||||
return kLayoutGraphDefinitions;
|
||||
return kLayoutFunctionDefinitions;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -174,6 +190,7 @@ ValidationState_t::ValidationState_t(const spv_const_context ctx,
|
||||
pointer_size_and_alignment_(0),
|
||||
sampler_image_addressing_mode_(0),
|
||||
in_function_(false),
|
||||
graph_definition_region_(kGraphDefinitionOutside),
|
||||
num_of_warnings_(0),
|
||||
max_num_of_warnings_(max_warnings) {
|
||||
assert(opt && "Validator options may not be Null.");
|
||||
@@ -362,6 +379,10 @@ bool ValidationState_t::in_block() const {
|
||||
module_functions_.back().current_block() != nullptr;
|
||||
}
|
||||
|
||||
GraphDefinitionRegion ValidationState_t::graph_definition_region() const {
|
||||
return graph_definition_region_;
|
||||
}
|
||||
|
||||
void ValidationState_t::RegisterCapability(spv::Capability cap) {
|
||||
// Avoid redundant work. Otherwise the recursion could induce work
|
||||
// quadratic in the capability dependency depth. (Ok, not much, but
|
||||
@@ -532,6 +553,13 @@ spv_result_t ValidationState_t::RegisterFunctionEnd() {
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Records which region of a graph definition the validator is currently in.
// Debug builds assert that the region only ever stays put or advances, with
// one exception: moving from kGraphDefinitionOutputs back to
// kGraphDefinitionOutside.
void ValidationState_t::SetGraphDefinitionRegion(GraphDefinitionRegion region) {
  assert(region >= graph_definition_region_ ||
         (graph_definition_region_ == kGraphDefinitionOutputs &&
          region == kGraphDefinitionOutside));
  graph_definition_region_ = region;
}
|
||||
|
||||
Instruction* ValidationState_t::AddOrderedInstruction(
|
||||
const spv_parsed_instruction_t* inst) {
|
||||
ordered_instructions_.emplace_back(inst);
|
||||
@@ -875,9 +903,12 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
|
||||
case spv::Op::OpTypeFloat:
|
||||
case spv::Op::OpTypeInt:
|
||||
case spv::Op::OpTypeBool:
|
||||
case spv::Op::OpTypePointer:
|
||||
case spv::Op::OpTypeUntypedPointerKHR:
|
||||
return id;
|
||||
|
||||
case spv::Op::OpTypeArray:
|
||||
case spv::Op::OpTypeRuntimeArray:
|
||||
return inst->word(2);
|
||||
|
||||
case spv::Op::OpTypeVector:
|
||||
@@ -939,11 +970,20 @@ uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(component_type_id);
|
||||
assert(inst);
|
||||
|
||||
if (inst->opcode() == spv::Op::OpTypeFloat ||
|
||||
inst->opcode() == spv::Op::OpTypeInt)
|
||||
return inst->word(2);
|
||||
|
||||
if (inst->opcode() == spv::Op::OpTypeBool) return 1;
|
||||
switch (inst->opcode()) {
|
||||
case spv::Op::OpTypeFloat:
|
||||
case spv::Op::OpTypeInt:
|
||||
return inst->word(2);
|
||||
case spv::Op::OpTypeBool:
|
||||
return 1;
|
||||
case spv::Op::OpTypePointer:
|
||||
case spv::Op::OpTypeUntypedPointerKHR:
|
||||
assert(inst->GetOperandAs<spv::StorageClass>(1) ==
|
||||
spv::StorageClass::PhysicalStorageBuffer);
|
||||
return 64; // all pointers to another PSB is 64-bit
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
assert(0);
|
||||
return 0;
|
||||
@@ -958,6 +998,23 @@ bool ValidationState_t::IsScalarType(uint32_t id) const {
|
||||
return IsIntScalarType(id) || IsFloatScalarType(id) || IsBoolScalarType(id);
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsArrayType(uint32_t id, uint64_t length) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (!inst || inst->opcode() != spv::Op::OpTypeArray) {
|
||||
return false;
|
||||
}
|
||||
if (length != 0) {
|
||||
const auto len_id = inst->GetOperandAs<uint32_t>(2);
|
||||
const auto len = FindDef(len_id);
|
||||
uint64_t len_value = 0;
|
||||
if (!len || !spvOpcodeIsConstant(len->opcode()) ||
|
||||
(EvalConstantValUint64(len_id, &len_value) && (length != len_value))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsBfloat16ScalarType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
|
||||
@@ -984,6 +1041,24 @@ bool ValidationState_t::IsBfloat16VectorType(uint32_t id) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsBfloat16CoopMatType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
|
||||
return IsBfloat16ScalarType(inst->word(2));
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsBfloat16Type(uint32_t id) const {
|
||||
return IsBfloat16ScalarType(id) || IsBfloat16VectorType(id) ||
|
||||
IsBfloat16CoopMatType(id);
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFP8ScalarType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
|
||||
@@ -1011,8 +1086,21 @@ bool ValidationState_t::IsFP8VectorType(uint32_t id) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFP8ScalarOrVectorType(uint32_t id) const {
|
||||
return IsFP8ScalarType(id) || IsFP8VectorType(id);
|
||||
bool ValidationState_t::IsFP8CoopMatType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
|
||||
return IsFP8ScalarType(inst->word(2));
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFP8Type(uint32_t id) const {
|
||||
return IsFP8ScalarType(id) || IsFP8VectorType(id) || IsFP8CoopMatType(id);
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
|
||||
@@ -1021,16 +1109,7 @@ bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFloatArrayType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inst->opcode() == spv::Op::OpTypeArray) {
|
||||
return IsFloatScalarType(GetComponentType(id));
|
||||
}
|
||||
|
||||
return false;
|
||||
return IsArrayType(id) && IsFloatScalarType(GetComponentType(id));
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
|
||||
@@ -1077,36 +1156,27 @@ bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsIntScalarType(uint32_t id) const {
|
||||
bool ValidationState_t::IsIntScalarType(uint32_t id, uint32_t width) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
return inst && inst->opcode() == spv::Op::OpTypeInt;
|
||||
bool is_int = inst && inst->opcode() == spv::Op::OpTypeInt;
|
||||
if (!is_int) {
|
||||
return false;
|
||||
}
|
||||
if ((width != 0) && (width != inst->word(2))) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsIntScalarTypeWithSignedness(
|
||||
uint32_t id, uint32_t signedness) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
return inst && inst->opcode() == spv::Op::OpTypeInt &&
|
||||
inst->word(3) == signedness;
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsIntArrayType(uint32_t id, uint64_t length) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inst->opcode() != spv::Op::OpTypeArray) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsIntScalarType(GetComponentType(id))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (length != 0) {
|
||||
const auto len_id = inst->GetOperandAs<uint32_t>(2);
|
||||
const auto len = FindDef(len_id);
|
||||
uint64_t len_value = 0;
|
||||
if (!len || !spvOpcodeIsConstant(len->opcode()) ||
|
||||
(EvalConstantValUint64(len_id, &len_value) && (length != len_value))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return IsArrayType(id, length) && IsIntScalarType(GetComponentType(id));
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsIntVectorType(uint32_t id) const {
|
||||
@@ -1140,8 +1210,7 @@ bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
return inst && inst->opcode() == spv::Op::OpTypeInt && inst->word(3) == 0;
|
||||
return IsIntScalarTypeWithSignedness(id, 0);
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
|
||||
@@ -1312,6 +1381,28 @@ bool ValidationState_t::GetPointerTypeInfo(
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t ValidationState_t::GetLargestScalarType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
|
||||
switch (inst->opcode()) {
|
||||
case spv::Op::OpTypeStruct: {
|
||||
uint32_t size = 0;
|
||||
for (uint32_t i = 1; i < inst->operands().size(); ++i) {
|
||||
const uint32_t member_size =
|
||||
GetLargestScalarType(inst->GetOperandAs<uint32_t>(i));
|
||||
size = std::max(size, member_size);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
case spv::Op::OpTypeArray:
|
||||
return GetLargestScalarType(inst->GetOperandAs<uint32_t>(1));
|
||||
case spv::Op::OpTypeVector:
|
||||
return GetLargestScalarType(inst->GetOperandAs<uint32_t>(1));
|
||||
default:
|
||||
return GetBitWidth(id) / 8;
|
||||
}
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
return inst && inst->opcode() == spv::Op::OpTypeAccelerationStructureKHR;
|
||||
@@ -1411,6 +1502,11 @@ bool ValidationState_t::IsUnsignedIntCooperativeVectorNVType(
|
||||
return IsUnsignedIntScalarType(FindDef(id)->word(2));
|
||||
}
|
||||
|
||||
bool ValidationState_t::IsTensorType(uint32_t id) const {
|
||||
const Instruction* inst = FindDef(id);
|
||||
return inst && inst->opcode() == spv::Op::OpTypeTensorARM;
|
||||
}
|
||||
|
||||
spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
|
||||
const Instruction* inst, uint32_t result_type_id, uint32_t m2,
|
||||
bool is_conversion, bool swap_row_col) {
|
||||
@@ -1445,8 +1541,7 @@ spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
|
||||
|
||||
if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
|
||||
return diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Expected scopes of Matrix and Result Type to be "
|
||||
<< "identical";
|
||||
<< "Expected scopes of Matrix and Result Type to be " << "identical";
|
||||
}
|
||||
|
||||
std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
|
||||
@@ -1949,6 +2044,14 @@ bool ValidationState_t::IsValidStorageClass(
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string ValidationState_t::MissingFeature(const std::string& feature,
|
||||
const std::string& cmdline,
|
||||
bool hint) const {
|
||||
return "\nThis is " + (hint ? std::string("may be ") : "") +
|
||||
"allowed if you enable the " + feature + " (or use the " + cmdline +
|
||||
" command line flag)";
|
||||
}
|
||||
|
||||
#define VUID_WRAP(vuid) "[" #vuid "] "
|
||||
|
||||
// Currently no 2 VUID share the same id, so no need for |reference|
|
||||
@@ -2211,6 +2314,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-Position-Position-04321);
|
||||
case 4330:
|
||||
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04330);
|
||||
case 4333:
|
||||
return VUID_WRAP(VUID-PrimitiveId-Fragment-04333);
|
||||
case 4334:
|
||||
return VUID_WRAP(VUID-PrimitiveId-PrimitiveId-04334);
|
||||
case 4336:
|
||||
@@ -2399,10 +2504,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-None-04644);
|
||||
case 4645:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-None-04645);
|
||||
case 10609:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpControlBarrier-10609);
|
||||
case 4650:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpControlBarrier-04650);
|
||||
case 4651:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpVariable-04651);
|
||||
case 4652:
|
||||
@@ -2469,14 +2570,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-PhysicalStorageBuffer64-04710);
|
||||
case 4711:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpTypeForwardPointer-04711);
|
||||
case 4730:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpAtomicStore-04730);
|
||||
case 4731:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpAtomicLoad-04731);
|
||||
case 4732:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpMemoryBarrier-04732);
|
||||
case 4733:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpMemoryBarrier-04733);
|
||||
case 4734:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpVariable-04734);
|
||||
case 4744:
|
||||
@@ -2485,8 +2578,6 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpImage-04777);
|
||||
case 4780:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-Result-04780);
|
||||
case 4781:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-Base-04781);
|
||||
case 4915:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-Location-04915);
|
||||
case 4916:
|
||||
@@ -2511,6 +2602,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-Flat-06202);
|
||||
case 6214:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpTypeImage-06214);
|
||||
case 6314:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-PhysicalStorageBuffer64-06314);
|
||||
case 6491:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-DescriptorSet-06491);
|
||||
case 6671:
|
||||
@@ -2621,6 +2714,10 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpEntryPoint-09658);
|
||||
case 9659:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpEntryPoint-09659);
|
||||
case 10151:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-DerivativeGroupQuadsKHR-10151);
|
||||
case 10152:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-DerivativeGroupLinearKHR-10152);
|
||||
case 10213:
|
||||
// This used to be a standalone, but maintenance8 will set allow_offset_texture_operand now
|
||||
return VUID_WRAP(VUID-RuntimeSpirv-Offset-10213);
|
||||
@@ -2628,10 +2725,71 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpTypeFloat-10370);
|
||||
case 10583:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-Component-10583);
|
||||
case 10589:
|
||||
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-10589);
|
||||
case 10590:
|
||||
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-10590);
|
||||
case 10591:
|
||||
return VUID_WRAP(VUID-CullPrimitiveEXT-CullPrimitiveEXT-10591);
|
||||
case 10592:
|
||||
return VUID_WRAP(VUID-Layer-Layer-10592);
|
||||
case 10593:
|
||||
return VUID_WRAP(VUID-Layer-Layer-10593);
|
||||
case 10594:
|
||||
return VUID_WRAP(VUID-Layer-Layer-10594);
|
||||
case 10598:
|
||||
return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-10598);
|
||||
case 10599:
|
||||
return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-10599);
|
||||
case 10600:
|
||||
return VUID_WRAP(VUID-PrimitiveShadingRateKHR-PrimitiveShadingRateKHR-10600);
|
||||
case 10601:
|
||||
return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-10601);
|
||||
case 10602:
|
||||
return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-10602);
|
||||
case 10603:
|
||||
return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-10603);
|
||||
case 10684:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-None-10684);
|
||||
case 10685:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-None-10685);
|
||||
case 10824:
|
||||
// This used to be a standalone, but maintenance9 will set allow_vulkan_32_bit_bitwise now
|
||||
return VUID_WRAP(VUID-RuntimeSpirv-None-10824);
|
||||
case 10865:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10865);
|
||||
case 10866:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10866);
|
||||
case 10867:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10867);
|
||||
case 10868:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10868);
|
||||
case 10869:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10869);
|
||||
case 10870:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10870);
|
||||
case 10871:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10871);
|
||||
case 10872:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10872);
|
||||
case 10873:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10873);
|
||||
case 10874:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-MemorySemantics-10874);
|
||||
case 10875:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10875);
|
||||
case 10876:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10876);
|
||||
case 10877:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10877);
|
||||
case 10878:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10878);
|
||||
case 10879:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-UnequalMemorySemantics-10879);
|
||||
case 10880:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-TessLevelInner-10880);
|
||||
case 11167:
|
||||
return VUID_WRAP(VUID-StandaloneSpirv-OpUntypedVariableKHR-11167);
|
||||
default:
|
||||
return ""; // unknown id
|
||||
}
|
||||
|
||||
110
3rdparty/spirv-tools/source/val/validation_state.h
vendored
110
3rdparty/spirv-tools/source/val/validation_state.h
vendored
@@ -50,6 +50,7 @@ enum ModuleLayoutSection {
|
||||
kLayoutExtInstImport, /// < Section 2.4 #3
|
||||
kLayoutMemoryModel, /// < Section 2.4 #4
|
||||
kLayoutSamplerImageAddressMode, /// < Section 2.4 #5
|
||||
/// (SPV_NV_bindless_texture)
|
||||
kLayoutEntryPoint, /// < Section 2.4 #6
|
||||
kLayoutExecutionMode, /// < Section 2.4 #7
|
||||
kLayoutDebug1, /// < Section 2.4 #8 > 1
|
||||
@@ -58,7 +59,18 @@ enum ModuleLayoutSection {
|
||||
kLayoutAnnotations, /// < Section 2.4 #9
|
||||
kLayoutTypes, /// < Section 2.4 #10
|
||||
kLayoutFunctionDeclarations, /// < Section 2.4 #11
|
||||
kLayoutFunctionDefinitions /// < Section 2.4 #12
|
||||
kLayoutFunctionDefinitions, /// < Section 2.4 #12
|
||||
kLayoutGraphDefinitions /// < Section 2.4 #13 (SPV_ARM_graph)
|
||||
};
|
||||
|
||||
/// This enum represents the regions of a graph definition. The relative
|
||||
/// ordering of the values is significant.
|
||||
enum GraphDefinitionRegion {
|
||||
kGraphDefinitionOutside,
|
||||
kGraphDefinitionBegin,
|
||||
kGraphDefinitionInputs,
|
||||
kGraphDefinitionBody,
|
||||
kGraphDefinitionOutputs,
|
||||
};
|
||||
|
||||
/// This class manages the state of the SPIR-V validation as it is being parsed.
|
||||
@@ -213,6 +225,9 @@ class ValidationState_t {
|
||||
/// instruction
|
||||
bool in_block() const;
|
||||
|
||||
/// Returns the region of a graph definition we are in.
|
||||
GraphDefinitionRegion graph_definition_region() const;
|
||||
|
||||
struct EntryPointDescription {
|
||||
std::string name;
|
||||
std::vector<uint32_t> interfaces;
|
||||
@@ -313,6 +328,16 @@ class ValidationState_t {
|
||||
/// ComputeFunctionToEntryPointMapping.
|
||||
void ComputeRecursiveEntryPoints();
|
||||
|
||||
/// Registers |id| as a graph entry point.
|
||||
void RegisterGraphEntryPoint(const uint32_t id) {
|
||||
graph_entry_points_.push_back(id);
|
||||
}
|
||||
|
||||
/// Returns a list of graph entry point graph ids
|
||||
const std::vector<uint32_t>& graph_entry_points() const {
|
||||
return graph_entry_points_;
|
||||
}
|
||||
|
||||
/// Returns all the entry points that can call |func|.
|
||||
const std::vector<uint32_t>& FunctionEntryPoints(uint32_t func) const;
|
||||
|
||||
@@ -350,6 +375,9 @@ class ValidationState_t {
|
||||
/// Register a function end instruction
|
||||
spv_result_t RegisterFunctionEnd();
|
||||
|
||||
/// Sets the region of a graph definition we're in.
|
||||
void SetGraphDefinitionRegion(GraphDefinitionRegion region);
|
||||
|
||||
/// Returns true if the capability is enabled in the module.
|
||||
bool HasCapability(spv::Capability cap) const {
|
||||
return module_capabilities_.contains(cap);
|
||||
@@ -632,23 +660,26 @@ class ValidationState_t {
|
||||
bool GetStructMemberTypes(uint32_t struct_type_id,
|
||||
std::vector<uint32_t>* member_types) const;
|
||||
|
||||
// Returns true iff |id| is a type corresponding to the name of the function.
|
||||
// Returns true if |id| is a type corresponding to the name of the function.
|
||||
// Only works for types not for objects.
|
||||
bool IsVoidType(uint32_t id) const;
|
||||
bool IsScalarType(uint32_t id) const;
|
||||
bool IsBfloat16ScalarType(uint32_t id) const;
|
||||
bool IsBfloat16VectorType(uint32_t id) const;
|
||||
bool IsBfloat16CoopMatType(uint32_t id) const;
|
||||
bool IsBfloat16Type(uint32_t id) const;
|
||||
bool IsFP8ScalarType(uint32_t id) const;
|
||||
bool IsFP8VectorType(uint32_t id) const;
|
||||
bool IsFP8ScalarOrVectorType(uint32_t id) const;
|
||||
bool IsFP8CoopMatType(uint32_t id) const;
|
||||
bool IsFP8Type(uint32_t id) const;
|
||||
bool IsFloatScalarType(uint32_t id) const;
|
||||
bool IsFloatArrayType(uint32_t id) const;
|
||||
bool IsFloatVectorType(uint32_t id) const;
|
||||
bool IsFloat16Vector2Or4Type(uint32_t id) const;
|
||||
bool IsFloatScalarOrVectorType(uint32_t id) const;
|
||||
bool IsFloatMatrixType(uint32_t id) const;
|
||||
bool IsIntScalarType(uint32_t id) const;
|
||||
bool IsIntArrayType(uint32_t id, uint64_t length = 0) const;
|
||||
bool IsIntScalarType(uint32_t id, uint32_t width = 0) const;
|
||||
bool IsIntScalarTypeWithSignedness(uint32_t id, uint32_t signedness) const;
|
||||
bool IsIntVectorType(uint32_t id) const;
|
||||
bool IsIntScalarOrVectorType(uint32_t id) const;
|
||||
bool IsUnsignedIntScalarType(uint32_t id) const;
|
||||
@@ -675,6 +706,36 @@ class ValidationState_t {
|
||||
bool IsFloatCooperativeVectorNVType(uint32_t id) const;
|
||||
bool IsIntCooperativeVectorNVType(uint32_t id) const;
|
||||
bool IsUnsignedIntCooperativeVectorNVType(uint32_t id) const;
|
||||
bool IsTensorType(uint32_t id) const;
|
||||
// When |length| is not 0, return true only if the array length is equal to
|
||||
// |length| and the array length is not defined by a specialization constant.
|
||||
bool IsArrayType(uint32_t id, uint64_t length = 0) const;
|
||||
bool IsIntArrayType(uint32_t id, uint64_t length = 0) const;
|
||||
template <unsigned int N>
|
||||
bool IsIntNOrFP32OrFP16(unsigned int type_id) {
|
||||
return this->ContainsType(
|
||||
type_id,
|
||||
[](const Instruction* inst) {
|
||||
if (inst->opcode() == spv::Op::OpTypeInt) {
|
||||
return inst->GetOperandAs<uint32_t>(1) == N;
|
||||
} else if (inst->opcode() == spv::Op::OpTypeFloat) {
|
||||
if (inst->operands().size() > 2) {
|
||||
// Not IEEE
|
||||
return false;
|
||||
}
|
||||
auto width = inst->GetOperandAs<uint32_t>(1);
|
||||
return width == 32 || width == 16;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
/* traverse_all_types = */ false);
|
||||
}
|
||||
|
||||
// Will walk the type to find the largest scalar value size.
|
||||
// Returns value is in bytes.
|
||||
// This is designed to pass in the %type from a PSB pointer
|
||||
// %ptr = OpTypePointer PhysicalStorageBuffer %type
|
||||
uint32_t GetLargestScalarType(uint32_t id) const;
|
||||
|
||||
// Returns true if |id| is a type id that contains |type| (or integer or
|
||||
// floating point type) of |width| bits.
|
||||
@@ -715,6 +776,17 @@ class ValidationState_t {
|
||||
bool GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
|
||||
spv::StorageClass* storage_class) const;
|
||||
|
||||
// Returns the value assocated with id via 'value' if id is an OpConstant
|
||||
template <typename T>
|
||||
bool GetConstantValueAs(unsigned int id, T& value) {
|
||||
const auto inst = FindDef(id);
|
||||
uint64_t ui64_val = 0u;
|
||||
bool status = (inst && spvOpcodeIsConstant(inst->opcode()) &&
|
||||
EvalConstantValUint64(id, &ui64_val));
|
||||
if (status == true) value = static_cast<T>(ui64_val);
|
||||
return status;
|
||||
}
|
||||
|
||||
// Is the ID the type of a pointer to a uniform block: Block-decorated struct
|
||||
// in uniform storage class? The result is only valid after internal method
|
||||
// CheckDecorationsOfBuffers has been called.
|
||||
@@ -772,6 +844,16 @@ class ValidationState_t {
|
||||
pointer_to_storage_image_.insert(type_id);
|
||||
}
|
||||
|
||||
// Is the ID the type of a pointer to a tensor? That is, the pointee
|
||||
// type is a tensor type.
|
||||
bool IsPointerToTensor(uint32_t type_id) const {
|
||||
return pointer_to_tensor_.find(type_id) != pointer_to_tensor_.cend();
|
||||
}
|
||||
// Save the ID of a pointer to a tensor.
|
||||
void RegisterPointerToTensor(uint32_t type_id) {
|
||||
pointer_to_tensor_.insert(type_id);
|
||||
}
|
||||
|
||||
// Tries to evaluate any scalar integer OpConstant as uint64.
|
||||
// OpConstantNull is defined as zero for scalar int (will return true)
|
||||
// OpSpecConstant* return false since their values cannot be relied upon
|
||||
@@ -844,6 +926,12 @@ class ValidationState_t {
|
||||
// Validates the storage class for the target environment.
|
||||
bool IsValidStorageClass(spv::StorageClass storage_class) const;
|
||||
|
||||
// Helps formulate a message to the user that setting one of the validator
|
||||
// options might make their SPIR-V actually valid. The |hint| option exists because
|
||||
// some checks are intertwined with each other, so hard to give confirmation
|
||||
std::string MissingFeature(const std::string& feature,
|
||||
const std::string& cmdline, bool hint) const;
|
||||
|
||||
// Takes a Vulkan Valid Usage ID (VUID) as |id| and optional |reference| and
|
||||
// will return a non-empty string only if ID is known and targeting Vulkan.
|
||||
// VUIDs are found in the Vulkan-Docs repo in the form "[[VUID-ref-ref-id]]"
|
||||
@@ -939,6 +1027,9 @@ class ValidationState_t {
|
||||
/// graph that recurses.
|
||||
std::set<uint32_t> recursive_entry_points_;
|
||||
|
||||
/// IDs that are graph entry points, ie, arguments to OpGraphEntryPointARM.
|
||||
std::vector<uint32_t> graph_entry_points_;
|
||||
|
||||
/// Functions IDs that are target of OpFunctionCall.
|
||||
std::unordered_set<uint32_t> function_call_targets_;
|
||||
|
||||
@@ -981,9 +1072,13 @@ class ValidationState_t {
|
||||
/// bit width of sampler/image type variables. Valid values are 32 and 64
|
||||
uint32_t sampler_image_addressing_mode_;
|
||||
|
||||
/// NOTE: See correspoding getter functions
|
||||
/// NOTE: See corresponding getter functions
|
||||
bool in_function_;
|
||||
|
||||
/// Where in a graph definition we are
|
||||
/// NOTE: See corresponding getter/setter functions
|
||||
GraphDefinitionRegion graph_definition_region_;
|
||||
|
||||
/// The state of optional features. These are determined by capabilities
|
||||
/// declared by the module and the environment.
|
||||
Feature features_;
|
||||
@@ -1030,6 +1125,9 @@ class ValidationState_t {
|
||||
// The IDs of types of pointers to storage images. This is populated in the
|
||||
// TypePass.
|
||||
std::unordered_set<uint32_t> pointer_to_storage_image_;
|
||||
// The IDs of types of pointers to tensors. This is populated in the
|
||||
// TypePass.
|
||||
std::unordered_set<uint32_t> pointer_to_tensor_;
|
||||
|
||||
/// Maps ids to friendly names.
|
||||
std::unique_ptr<spvtools::FriendlyNameMapper> friendly_mapper_;
|
||||
|
||||
Reference in New Issue
Block a user