diff --git a/3rdparty/spirv-tools/include/generated/build-version.inc b/3rdparty/spirv-tools/include/generated/build-version.inc index b052f188f..d4c98e877 100644 --- a/3rdparty/spirv-tools/include/generated/build-version.inc +++ b/3rdparty/spirv-tools/include/generated/build-version.inc @@ -1 +1 @@ -"v2022.3-dev", "SPIRV-Tools v2022.3-dev 560669f6c0e19daf8f29e1f085599f0765e4ee35" +"v2022.3-dev", "SPIRV-Tools v2022.3-dev b8091498a3d8f2f7a46df2009d7c340eef1939b6" diff --git a/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp b/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp index df830d7f8..949735608 100644 --- a/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp +++ b/3rdparty/spirv-tools/include/spirv-tools/optimizer.hpp @@ -903,6 +903,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass( const std::vector& descriptor_set_binding_pairs); +// Create an interface-variable-scalar-replacement pass that replaces array or +// matrix interface variables with a series of scalar or vector interface +// variables. For example, it replaces `float3 foo[2]` with `float3 foo0, foo1`. +Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass(); + // Creates a remove-dont-inline pass to remove the |DontInline| function control // from every function in the module. This is useful if you want the inliner to // inline these functions some reason. diff --git a/3rdparty/spirv-tools/source/opt/interface_var_sroa.cpp b/3rdparty/spirv-tools/source/opt/interface_var_sroa.cpp new file mode 100644 index 000000000..58ed897c1 --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/interface_var_sroa.cpp @@ -0,0 +1,964 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/interface_var_sroa.h" + +#include + +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/function.h" +#include "source/opt/log.h" +#include "source/opt/type_manager.h" +#include "source/util/make_unique.h" + +const static uint32_t kOpDecorateDecorationInOperandIndex = 1; +const static uint32_t kOpDecorateLiteralInOperandIndex = 2; +const static uint32_t kOpEntryPointInOperandInterface = 3; +const static uint32_t kOpVariableStorageClassInOperandIndex = 0; +const static uint32_t kOpTypeArrayElemTypeInOperandIndex = 0; +const static uint32_t kOpTypeArrayLengthInOperandIndex = 1; +const static uint32_t kOpTypeMatrixColCountInOperandIndex = 1; +const static uint32_t kOpTypeMatrixColTypeInOperandIndex = 0; +const static uint32_t kOpTypePtrTypeInOperandIndex = 1; +const static uint32_t kOpConstantValueInOperandIndex = 0; + +namespace spvtools { +namespace opt { +namespace { + +// Get the length of the OpTypeArray |array_type|. +uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr, + Instruction* array_type) { + assert(array_type->opcode() == SpvOpTypeArray); + uint32_t const_int_id = + array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex); + Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id); + assert(array_length_inst->opcode() == SpvOpConstant); + return array_length_inst->GetSingleWordInOperand( + kOpConstantValueInOperandIndex); +} + +// Get the element type instruction of the OpTypeArray |array_type|. +Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr, + Instruction* array_type) { + assert(array_type->opcode() == SpvOpTypeArray); + uint32_t elem_type_id = + array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); + return def_use_mgr->GetDef(elem_type_id); +} + +// Get the column type instruction of the OpTypeMatrix |matrix_type|. +Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr, + Instruction* matrix_type) { + assert(matrix_type->opcode() == SpvOpTypeMatrix); + uint32_t column_type_id = + matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); + return def_use_mgr->GetDef(column_type_id); +} + +// Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it +// |depth_to_component| times recursively and returns the component type. +// |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction. +uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr, + uint32_t type_id, + uint32_t depth_to_component) { + if (depth_to_component == 0) return type_id; + + Instruction* type_inst = def_use_mgr->GetDef(type_id); + if (type_inst->opcode() == SpvOpTypeArray) { + uint32_t elem_type_id = + type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); + return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id, + depth_to_component - 1); + } + + assert(type_inst->opcode() == SpvOpTypeMatrix); + uint32_t column_type_id = + type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); + return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id, + depth_to_component - 1); +} + +// Creates an OpDecorate instruction whose Target is |var_id| and Decoration is +// |decoration|. Adds |literal| as an extra operand of the instruction. +void CreateDecoration(analysis::DecorationManager* decoration_mgr, + uint32_t var_id, SpvDecoration decoration, + uint32_t literal) { + std::vector operands({ + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION, + {static_cast(decoration)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}}, + }); + decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands)); +} + +// Replaces load instructions with composite construct instructions in all the +// users of the loads. |loads_to_composites| is the mapping from each load to +// its corresponding OpCompositeConstruct. +void ReplaceLoadWithCompositeConstruct( + IRContext* context, + const std::unordered_map& loads_to_composites) { + for (const auto& load_and_composite : loads_to_composites) { + Instruction* load = load_and_composite.first; + Instruction* composite_construct = load_and_composite.second; + + std::vector users; + context->get_def_use_mgr()->ForEachUse( + load, [&users, composite_construct](Instruction* user, uint32_t index) { + user->GetOperand(index).words[0] = composite_construct->result_id(); + users.push_back(user); + }); + + for (Instruction* user : users) + context->get_def_use_mgr()->AnalyzeInstUse(user); + } +} + +// Returns the storage class of the instruction |var|. +SpvStorageClass GetStorageClass(Instruction* var) { + return static_cast( + var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex)); +} + +} // namespace + +bool InterfaceVariableScalarReplacement::HasExtraArrayness( + Instruction& entry_point, Instruction* var) { + SpvExecutionModel execution_model = + static_cast(entry_point.GetSingleWordInOperand(0)); + if (execution_model != SpvExecutionModelTessellationEvaluation && + execution_model != SpvExecutionModelTessellationControl) { + return false; + } + if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(), + SpvDecorationPatch)) { + if (execution_model == SpvExecutionModelTessellationControl) return true; + return GetStorageClass(var) != SpvStorageClassOutput; + } + return false; +} + +bool InterfaceVariableScalarReplacement:: + CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var, + bool has_extra_arrayness) { + if (has_extra_arrayness) { + return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var); + } + return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var); +} + +bool InterfaceVariableScalarReplacement::GetVariableLocation( + Instruction* var, uint32_t* location) { + return !context()->get_decoration_mgr()->WhileEachDecoration( + var->result_id(), SpvDecorationLocation, + [location](const Instruction& inst) { + *location = + inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); + return false; + }); +} + +bool InterfaceVariableScalarReplacement::GetVariableComponent( + Instruction* var, uint32_t* component) { + return !context()->get_decoration_mgr()->WhileEachDecoration( + var->result_id(), SpvDecorationComponent, + [component](const Instruction& inst) { + *component = + inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); + return false; + }); +} + +std::vector +InterfaceVariableScalarReplacement::CollectInterfaceVariables( + Instruction& entry_point) { + std::vector interface_vars; + for (uint32_t i = kOpEntryPointInOperandInterface; + i < entry_point.NumInOperands(); ++i) { + Instruction* interface_var = context()->get_def_use_mgr()->GetDef( + entry_point.GetSingleWordInOperand(i)); + assert(interface_var->opcode() == SpvOpVariable); + + SpvStorageClass storage_class = GetStorageClass(interface_var); + if (storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + continue; + } + + interface_vars.push_back(interface_var); + } + return interface_vars; +} + +void InterfaceVariableScalarReplacement::KillInstructionAndUsers( + Instruction* inst) { + if (inst->opcode() == SpvOpEntryPoint) { + return; + } + if (inst->opcode() != SpvOpAccessChain) { + context()->KillInst(inst); + return; + } + context()->get_def_use_mgr()->ForEachUser( + inst, [this](Instruction* user) { KillInstructionAndUsers(user); }); + context()->KillInst(inst); +} + +void InterfaceVariableScalarReplacement::KillInstructionsAndUsers( + const std::vector& insts) { + for (Instruction* inst : insts) { + KillInstructionAndUsers(inst); + } +} + +void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations( + uint32_t var_id) { + context()->get_decoration_mgr()->RemoveDecorationsFrom( + var_id, [](const Instruction& inst) { + uint32_t decoration = + inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex); + return decoration == SpvDecorationLocation || + decoration == SpvDecorationComponent; + }); +} + +bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars( + Instruction* interface_var, Instruction* interface_var_type, + uint32_t location, uint32_t component, uint32_t extra_array_length) { + NestedCompositeComponents scalar_interface_vars = + CreateScalarInterfaceVarsForReplacement(interface_var_type, + GetStorageClass(interface_var), + extra_array_length); + + AddLocationAndComponentDecorations(scalar_interface_vars, &location, + component); + KillLocationAndComponentDecorations(interface_var->result_id()); + + if (!ReplaceInterfaceVarWith(interface_var, extra_array_length, + scalar_interface_vars)) { + return false; + } + + context()->KillInst(interface_var); + return true; +} + +bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith( + Instruction* interface_var, uint32_t extra_array_length, + const NestedCompositeComponents& scalar_interface_vars) { + std::vector users; + context()->get_def_use_mgr()->ForEachUser( + interface_var, [&users](Instruction* user) { users.push_back(user); }); + + std::vector interface_var_component_indices; + std::unordered_map loads_to_composites; + std::unordered_map + loads_for_access_chain_to_composites; + if (extra_array_length != 0) { + // Note that the extra arrayness is the first dimension of the array + // interface variable. + for (uint32_t index = 0; index < extra_array_length; ++index) { + std::unordered_map loads_to_component_values; + if (!ReplaceComponentsOfInterfaceVarWith( + interface_var, users, scalar_interface_vars, + interface_var_component_indices, &index, + &loads_to_component_values, + &loads_for_access_chain_to_composites)) { + return false; + } + AddComponentsToCompositesForLoads(loads_to_component_values, + &loads_to_composites, 0); + } + } else if (!ReplaceComponentsOfInterfaceVarWith( + interface_var, users, scalar_interface_vars, + interface_var_component_indices, nullptr, &loads_to_composites, + &loads_for_access_chain_to_composites)) { + return false; + } + + ReplaceLoadWithCompositeConstruct(context(), loads_to_composites); + ReplaceLoadWithCompositeConstruct(context(), + loads_for_access_chain_to_composites); + + KillInstructionsAndUsers(users); + return true; +} + +void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations( + const NestedCompositeComponents& vars, uint32_t* location, + uint32_t component) { + if (!vars.HasMultipleComponents()) { + uint32_t var_id = vars.GetComponentVariable()->result_id(); + CreateDecoration(context()->get_decoration_mgr(), var_id, + SpvDecorationLocation, *location); + CreateDecoration(context()->get_decoration_mgr(), var_id, + SpvDecorationComponent, component); + ++(*location); + return; + } + for (const auto& var : vars.GetComponents()) { + AddLocationAndComponentDecorations(var, location, component); + } +} + +bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith( + Instruction* interface_var, + const std::vector& interface_var_users, + const NestedCompositeComponents& scalar_interface_vars, + std::vector& interface_var_component_indices, + const uint32_t* extra_array_index, + std::unordered_map* loads_to_composites, + std::unordered_map* + loads_for_access_chain_to_composites) { + if (!scalar_interface_vars.HasMultipleComponents()) { + for (Instruction* interface_var_user : interface_var_users) { + if (!ReplaceComponentOfInterfaceVarWith( + interface_var, interface_var_user, + scalar_interface_vars.GetComponentVariable(), + interface_var_component_indices, extra_array_index, + loads_to_composites, loads_for_access_chain_to_composites)) { + return false; + } + } + return true; + } + return ReplaceMultipleComponentsOfInterfaceVarWith( + interface_var, interface_var_users, scalar_interface_vars.GetComponents(), + interface_var_component_indices, extra_array_index, loads_to_composites, + loads_for_access_chain_to_composites); +} + +bool InterfaceVariableScalarReplacement:: + ReplaceMultipleComponentsOfInterfaceVarWith( + Instruction* interface_var, + const std::vector& interface_var_users, + const std::vector& components, + std::vector& interface_var_component_indices, + const uint32_t* extra_array_index, + std::unordered_map* loads_to_composites, + std::unordered_map* + loads_for_access_chain_to_composites) { + for (uint32_t i = 0; i < components.size(); ++i) { + interface_var_component_indices.push_back(i); + std::unordered_map loads_to_component_values; + std::unordered_map + loads_for_access_chain_to_component_values; + if (!ReplaceComponentsOfInterfaceVarWith( + interface_var, interface_var_users, components[i], + interface_var_component_indices, extra_array_index, + &loads_to_component_values, + &loads_for_access_chain_to_component_values)) { + return false; + } + interface_var_component_indices.pop_back(); + + uint32_t depth_to_component = + static_cast(interface_var_component_indices.size()); + AddComponentsToCompositesForLoads( + loads_for_access_chain_to_component_values, + loads_for_access_chain_to_composites, depth_to_component); + if (extra_array_index) ++depth_to_component; + AddComponentsToCompositesForLoads(loads_to_component_values, + loads_to_composites, depth_to_component); + } + return true; +} + +bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith( + Instruction* interface_var, Instruction* interface_var_user, + Instruction* scalar_var, + const std::vector& interface_var_component_indices, + const uint32_t* extra_array_index, + std::unordered_map* loads_to_component_values, + std::unordered_map* + loads_for_access_chain_to_component_values) { + SpvOp opcode = interface_var_user->opcode(); + if (opcode == SpvOpStore) { + uint32_t value_id = interface_var_user->GetSingleWordInOperand(1); + StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices, + scalar_var, extra_array_index, + interface_var_user); + return true; + } + if (opcode == SpvOpLoad) { + Instruction* scalar_load = + LoadScalarVar(scalar_var, extra_array_index, interface_var_user); + loads_to_component_values->insert({interface_var_user, scalar_load}); + return true; + } + + // Copy OpName and annotation instructions only once. Therefore, we create + // them only for the first element of the extra array. + if (extra_array_index && *extra_array_index != 0) return true; + + if (opcode == SpvOpDecorateId || opcode == SpvOpDecorateString || + opcode == SpvOpDecorate) { + CloneAnnotationForVariable(interface_var_user, scalar_var->result_id()); + return true; + } + + if (opcode == SpvOpName) { + std::unique_ptr new_inst(interface_var_user->Clone(context())); + new_inst->SetInOperand(0, {scalar_var->result_id()}); + context()->AddDebug2Inst(std::move(new_inst)); + return true; + } + + if (opcode == SpvOpEntryPoint) { + return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user, + scalar_var->result_id()); + } + + if (opcode == SpvOpAccessChain) { + ReplaceAccessChainWith(interface_var_user, interface_var_component_indices, + scalar_var, + loads_for_access_chain_to_component_values); + return true; + } + + std::string message("Unhandled instruction"); + message += "\n " + interface_var_user->PrettyPrint( + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + message += + "\nfor interface variable scalar replacement\n " + + interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return false; +} + +void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain( + Instruction* access_chain, Instruction* base_access_chain) { + assert(base_access_chain->opcode() == SpvOpAccessChain && + access_chain->opcode() == SpvOpAccessChain && + access_chain->GetSingleWordInOperand(0) == + base_access_chain->result_id()); + Instruction::OperandList new_operands; + for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) { + new_operands.emplace_back(base_access_chain->GetInOperand(i)); + } + for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) { + new_operands.emplace_back(access_chain->GetInOperand(i)); + } + access_chain->SetInOperands(std::move(new_operands)); +} + +Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar( + uint32_t var_type_id, Instruction* var, + const std::vector& index_ids, Instruction* insert_before, + uint32_t* component_type_id) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + *component_type_id = GetComponentTypeOfArrayMatrix( + def_use_mgr, var_type_id, static_cast(index_ids.size())); + + uint32_t ptr_type_id = + GetPointerType(*component_type_id, GetStorageClass(var)); + + std::unique_ptr new_access_chain( + new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(), + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); + for (uint32_t index_id : index_ids) { + new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}}); + } + + Instruction* inst = new_access_chain.get(); + def_use_mgr->AnalyzeInstDefUse(inst); + insert_before->InsertBefore(std::move(new_access_chain)); + return inst; +} + +Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex( + uint32_t component_type_id, Instruction* var, uint32_t index, + Instruction* insert_before) { + uint32_t ptr_type_id = + GetPointerType(component_type_id, GetStorageClass(var)); + uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index); + std::unique_ptr new_access_chain( + new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(), + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}, + {SPV_OPERAND_TYPE_ID, {index_id}}, + })); + Instruction* inst = new_access_chain.get(); + context()->get_def_use_mgr()->AnalyzeInstDefUse(inst); + insert_before->InsertBefore(std::move(new_access_chain)); + return inst; +} + +void InterfaceVariableScalarReplacement::ReplaceAccessChainWith( + Instruction* access_chain, + const std::vector& interface_var_component_indices, + Instruction* scalar_var, + std::unordered_map* loads_to_component_values) { + std::vector indexes; + for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) { + indexes.push_back(access_chain->GetSingleWordInOperand(i)); + } + + // Note that we have a strong assumption that |access_chain| has only a single + // index that is for the extra arrayness. + context()->get_def_use_mgr()->ForEachUser( + access_chain, + [this, access_chain, &indexes, &interface_var_component_indices, + scalar_var, loads_to_component_values](Instruction* user) { + switch (user->opcode()) { + case SpvOpAccessChain: { + UseBaseAccessChainForAccessChain(user, access_chain); + ReplaceAccessChainWith(user, interface_var_component_indices, + scalar_var, loads_to_component_values); + return; + } + case SpvOpStore: { + uint32_t value_id = user->GetSingleWordInOperand(1); + StoreComponentOfValueToAccessChainToScalarVar( + value_id, interface_var_component_indices, scalar_var, indexes, + user); + return; + } + case SpvOpLoad: { + Instruction* value = + LoadAccessChainToVar(scalar_var, indexes, user); + loads_to_component_values->insert({user, value}); + return; + } + default: + break; + } + }); +} + +void InterfaceVariableScalarReplacement::CloneAnnotationForVariable( + Instruction* annotation_inst, uint32_t var_id) { + assert(annotation_inst->opcode() == SpvOpDecorate || + annotation_inst->opcode() == SpvOpDecorateId || + annotation_inst->opcode() == SpvOpDecorateString); + std::unique_ptr new_inst(annotation_inst->Clone(context())); + new_inst->SetInOperand(0, {var_id}); + context()->AddAnnotationInst(std::move(new_inst)); +} + +bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint( + Instruction* interface_var, Instruction* entry_point, + uint32_t scalar_var_id) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + uint32_t interface_var_id = interface_var->result_id(); + if (interface_vars_removed_from_entry_point_operands_.find( + interface_var_id) != + interface_vars_removed_from_entry_point_operands_.end()) { + entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}}); + def_use_mgr->AnalyzeInstUse(entry_point); + return true; + } + + bool success = !entry_point->WhileEachInId( + [&interface_var_id, &scalar_var_id](uint32_t* id) { + if (*id == interface_var_id) { + *id = scalar_var_id; + return false; + } + return true; + }); + if (!success) { + std::string message( + "interface variable is not an operand of the entry point"); + message += "\n " + interface_var->PrettyPrint( + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + message += "\n " + entry_point->PrettyPrint( + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return false; + } + + def_use_mgr->AnalyzeInstUse(entry_point); + interface_vars_removed_from_entry_point_operands_.insert(interface_var_id); + return true; +} + +uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar( + Instruction* var) { + assert(var->opcode() == SpvOpVariable); + + uint32_t ptr_type_id = var->type_id(); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id); + + assert(ptr_type_inst->opcode() == SpvOpTypePointer && + "Variable must have a pointer type."); + return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex); +} + +void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar( + uint32_t value_id, const std::vector& component_indices, + Instruction* scalar_var, const uint32_t* extra_array_index, + Instruction* insert_before) { + uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); + Instruction* ptr = scalar_var; + if (extra_array_index) { + auto* ty_mgr = context()->get_type_mgr(); + analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray(); + assert(array_type != nullptr); + component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type()); + ptr = CreateAccessChainWithIndex(component_type_id, scalar_var, + *extra_array_index, insert_before); + } + + StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr, + extra_array_index, insert_before); +} + +Instruction* InterfaceVariableScalarReplacement::LoadScalarVar( + Instruction* scalar_var, const uint32_t* extra_array_index, + Instruction* insert_before) { + uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); + Instruction* ptr = scalar_var; + if (extra_array_index) { + auto* ty_mgr = context()->get_type_mgr(); + analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray(); + assert(array_type != nullptr); + component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type()); + ptr = CreateAccessChainWithIndex(component_type_id, scalar_var, + *extra_array_index, insert_before); + } + + return CreateLoad(component_type_id, ptr, insert_before); +} + +Instruction* InterfaceVariableScalarReplacement::CreateLoad( + uint32_t type_id, Instruction* ptr, Instruction* insert_before) { + std::unique_ptr load( + new Instruction(context(), SpvOpLoad, type_id, TakeNextId(), + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}})); + Instruction* load_inst = load.get(); + context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst); + insert_before->InsertBefore(std::move(load)); + return load_inst; +} + +void InterfaceVariableScalarReplacement::StoreComponentOfValueTo( + uint32_t component_type_id, uint32_t value_id, + const std::vector& component_indices, Instruction* ptr, + const uint32_t* extra_array_index, Instruction* insert_before) { + std::unique_ptr composite_extract(CreateCompositeExtract( + component_type_id, value_id, component_indices, extra_array_index)); + + std::unique_ptr new_store( + new Instruction(context(), SpvOpStore)); + new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}}); + new_store->AddOperand( + {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}}); + + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + def_use_mgr->AnalyzeInstDefUse(composite_extract.get()); + def_use_mgr->AnalyzeInstDefUse(new_store.get()); + + insert_before->InsertBefore(std::move(composite_extract)); + insert_before->InsertBefore(std::move(new_store)); +} + +Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract( + uint32_t type_id, uint32_t composite_id, + const std::vector& indexes, const uint32_t* extra_first_index) { + uint32_t component_id = TakeNextId(); + Instruction* composite_extract = new Instruction( + context(), SpvOpCompositeExtract, type_id, component_id, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {composite_id}}}); + if (extra_first_index) { + composite_extract->AddOperand( + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}}); + } + for (uint32_t index : indexes) { + composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}); + } + return composite_extract; +} + +void InterfaceVariableScalarReplacement:: + StoreComponentOfValueToAccessChainToScalarVar( + uint32_t value_id, const std::vector& component_indices, + Instruction* scalar_var, + const std::vector& access_chain_indices, + Instruction* insert_before) { + uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); + Instruction* ptr = scalar_var; + if (!access_chain_indices.empty()) { + ptr = CreateAccessChainToVar(component_type_id, scalar_var, + access_chain_indices, insert_before, + &component_type_id); + } + + StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr, + nullptr, insert_before); +} + +Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar( + Instruction* var, const std::vector& indexes, + Instruction* insert_before) { + uint32_t component_type_id = GetPointeeTypeIdOfVar(var); + Instruction* ptr = var; + if (!indexes.empty()) { + ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before, + &component_type_id); + } + + return CreateLoad(component_type_id, ptr, insert_before); +} + +Instruction* +InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad( + Instruction* load, uint32_t depth_to_component) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + uint32_t type_id = load->type_id(); + if (depth_to_component != 0) { + type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(), + depth_to_component); + } + uint32_t new_id = context()->TakeNextId(); + std::unique_ptr new_composite_construct( + new Instruction(context(), SpvOpCompositeConstruct, type_id, new_id, {})); + Instruction* composite_construct = new_composite_construct.get(); + def_use_mgr->AnalyzeInstDefUse(composite_construct); + + // Insert |new_composite_construct| after |load|. When there are multiple + // recursive composite construct instructions for a load, we have to place the + // composite construct with a lower depth later because it constructs the + // composite that contains other composites with lower depths. + auto* insert_before = load->NextNode(); + while (true) { + auto itr = + composite_ids_to_component_depths.find(insert_before->result_id()); + if (itr == composite_ids_to_component_depths.end()) break; + if (itr->second <= depth_to_component) break; + insert_before = insert_before->NextNode(); + } + insert_before->InsertBefore(std::move(new_composite_construct)); + composite_ids_to_component_depths.insert({new_id, depth_to_component}); + return composite_construct; +} + +void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads( + const std::unordered_map& + loads_to_component_values, + std::unordered_map* loads_to_composites, + uint32_t depth_to_component) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + for (auto& load_and_component_vale : loads_to_component_values) { + Instruction* load = load_and_component_vale.first; + Instruction* component_value = load_and_component_vale.second; + Instruction* composite_construct = nullptr; + auto itr = loads_to_composites->find(load); + if (itr == loads_to_composites->end()) { + composite_construct = + CreateCompositeConstructForComponentOfLoad(load, depth_to_component); + loads_to_composites->insert({load, composite_construct}); + } else { + composite_construct = itr->second; + } + composite_construct->AddOperand( + {SPV_OPERAND_TYPE_ID, {component_value->result_id()}}); + def_use_mgr->AnalyzeInstDefUse(composite_construct); + } +} + +uint32_t InterfaceVariableScalarReplacement::GetArrayType( + uint32_t elem_type_id, uint32_t array_length) { + analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id); + uint32_t array_length_id = + context()->get_constant_mgr()->GetUIntConst(array_length); + analysis::Array array_type( + elem_type, + analysis::Array::LengthInfo{array_length_id, {0, array_length}}); + return context()->get_type_mgr()->GetTypeInstruction(&array_type); +} + +uint32_t InterfaceVariableScalarReplacement::GetPointerType( + uint32_t type_id, SpvStorageClass storage_class) { + analysis::Type* type = context()->get_type_mgr()->GetType(type_id); + analysis::Pointer ptr_type(type, storage_class); + return context()->get_type_mgr()->GetTypeInstruction(&ptr_type); +} + +InterfaceVariableScalarReplacement::NestedCompositeComponents +InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray( + Instruction* interface_var_type, SpvStorageClass storage_class, + uint32_t extra_array_length) { + assert(interface_var_type->opcode() == SpvOpTypeArray); + + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type); + Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type); + + NestedCompositeComponents scalar_vars; + while (array_length > 0) { + NestedCompositeComponents scalar_vars_for_element = + CreateScalarInterfaceVarsForReplacement(elem_type, storage_class, + extra_array_length); + scalar_vars.AddComponent(scalar_vars_for_element); + --array_length; + } + return scalar_vars; +} + +InterfaceVariableScalarReplacement::NestedCompositeComponents +InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix( + Instruction* interface_var_type, SpvStorageClass storage_class, + uint32_t extra_array_length) { + assert(interface_var_type->opcode() == SpvOpTypeMatrix); + + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + uint32_t column_count = interface_var_type->GetSingleWordInOperand( + kOpTypeMatrixColCountInOperandIndex); + Instruction* column_type = + GetMatrixColumnType(def_use_mgr, interface_var_type); + + NestedCompositeComponents scalar_vars; + while (column_count > 0) { + NestedCompositeComponents scalar_vars_for_column = + CreateScalarInterfaceVarsForReplacement(column_type, storage_class, + extra_array_length); + scalar_vars.AddComponent(scalar_vars_for_column); + --column_count; + } + return scalar_vars; +} + +InterfaceVariableScalarReplacement::NestedCompositeComponents +InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement( + Instruction* interface_var_type, SpvStorageClass storage_class, + uint32_t extra_array_length) { + // Handle array case. + if (interface_var_type->opcode() == SpvOpTypeArray) { + return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class, + extra_array_length); + } + + // Handle matrix case. + if (interface_var_type->opcode() == SpvOpTypeMatrix) { + return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class, + extra_array_length); + } + + // Handle scalar or vector case. + NestedCompositeComponents scalar_var; + uint32_t type_id = interface_var_type->result_id(); + if (extra_array_length != 0) { + type_id = GetArrayType(type_id, extra_array_length); + } + uint32_t ptr_type_id = + context()->get_type_mgr()->FindPointerToType(type_id, storage_class); + uint32_t id = TakeNextId(); + std::unique_ptr variable( + new Instruction(context(), SpvOpVariable, ptr_type_id, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, + {static_cast(storage_class)}}})); + scalar_var.SetSingleComponentVariable(variable.get()); + context()->AddGlobalValue(std::move(variable)); + return scalar_var; +} + +Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable( + Instruction* var) { + uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + return def_use_mgr->GetDef(pointee_type_id); +} + +Pass::Status InterfaceVariableScalarReplacement::Process() { + Pass::Status status = Status::SuccessWithoutChange; + for (Instruction& entry_point : get_module()->entry_points()) { + status = + CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point)); + } + return status; +} + +bool InterfaceVariableScalarReplacement:: + ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) { + if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end()) + return false; + + std::string message( + "A variable is arrayed for an entry point but it is not " + "arrayed for another entry point"); + message += + "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return true; +} + +bool InterfaceVariableScalarReplacement:: + ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) { + if (vars_without_extra_arrayness.find(var) == + vars_without_extra_arrayness.end()) + return false; + + std::string message( + "A variable is not arrayed for an entry point but it is " + "arrayed for another entry point"); + message += + "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return true; +} + +Pass::Status +InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars( + Instruction& entry_point) { + std::vector interface_vars = + CollectInterfaceVariables(entry_point); + + Pass::Status status = Status::SuccessWithoutChange; + for (Instruction* interface_var : interface_vars) { + uint32_t location, component; + if (!GetVariableLocation(interface_var, &location)) continue; + if (!GetVariableComponent(interface_var, &component)) component = 0; + + Instruction* interface_var_type = GetTypeOfVariable(interface_var); + uint32_t extra_array_length = 0; + if (HasExtraArrayness(entry_point, interface_var)) { + extra_array_length = + GetArrayLength(context()->get_def_use_mgr(), interface_var_type); + interface_var_type = + GetArrayElementType(context()->get_def_use_mgr(), interface_var_type); + vars_with_extra_arrayness.insert(interface_var); + } else { + vars_without_extra_arrayness.insert(interface_var); + } + + if (!CheckExtraArraynessConflictBetweenEntries(interface_var, + extra_array_length != 0)) { + return Pass::Status::Failure; + } + + if (interface_var_type->opcode() != SpvOpTypeArray && + interface_var_type->opcode() != SpvOpTypeMatrix) { + continue; + } + + if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type, + location, component, + extra_array_length)) { + return Pass::Status::Failure; + } + status = Pass::Status::SuccessWithChange; + } + + return status; +} + +} // namespace opt +} // namespace spvtools diff --git a/3rdparty/spirv-tools/source/opt/interface_var_sroa.h b/3rdparty/spirv-tools/source/opt/interface_var_sroa.h new file mode 100644 index 000000000..23baad0ad --- /dev/null +++ b/3rdparty/spirv-tools/source/opt/interface_var_sroa.h @@ -0,0 +1,401 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_ +#define SOURCE_OPT_INTERFACE_VAR_SROA_H_ + +#include + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +// +// Note that the current implementation of this pass covers only store, load, +// access chain instructions for the interface variables. Supporting other types +// of instructions is a future work. +class InterfaceVariableScalarReplacement : public Pass { + public: + InterfaceVariableScalarReplacement() {} + + const char* name() const override { + return "interface-variable-scalar-replacement"; + } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDecorations | IRContext::kAnalysisDefUse | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // A struct containing components of a composite variable. If the composite + // consists of multiple or recursive components, |component_variable| is + // nullptr and |nested_composite_components| keeps the components. If it has a + // single component, |nested_composite_components| is empty and + // |component_variable| is the component. Note that each element of + // |nested_composite_components| has the NestedCompositeComponents struct as + // its type that can recursively keep the components. + struct NestedCompositeComponents { + NestedCompositeComponents() : component_variable(nullptr) {} + + bool HasMultipleComponents() const { + return !nested_composite_components.empty(); + } + + const std::vector& GetComponents() const { + return nested_composite_components; + } + + void AddComponent(const NestedCompositeComponents& component) { + nested_composite_components.push_back(component); + } + + Instruction* GetComponentVariable() const { return component_variable; } + + void SetSingleComponentVariable(Instruction* var) { + component_variable = var; + } + + private: + std::vector nested_composite_components; + Instruction* component_variable; + }; + + // Collects all interface variables used by the |entry_point|. + std::vector CollectInterfaceVariables(Instruction& entry_point); + + // Returns whether |var| has the extra arrayness for the entry point + // |entry_point| or not. + bool HasExtraArrayness(Instruction& entry_point, Instruction* var); + + // Finds a Location BuiltIn decoration of |var| and returns it via + // |location|. Returns true whether the location exists or not. + bool GetVariableLocation(Instruction* var, uint32_t* location); + + // Finds a Component BuiltIn decoration of |var| and returns it via + // |component|. Returns true whether the component exists or not. + bool GetVariableComponent(Instruction* var, uint32_t* component); + + // Returns the interface variable instruction whose result id is + // |interface_var_id|. + Instruction* GetInterfaceVariable(uint32_t interface_var_id); + + // Returns the type of |var| as an instruction. + Instruction* GetTypeOfVariable(Instruction* var); + + // Replaces an interface variable |interface_var| whose type is + // |interface_var_type| with scalars and returns whether it succeeds or not. + // |location| is the value of Location Decoration for |interface_var|. + // |component| is the value of Component Decoration for |interface_var|. + // If |extra_array_length| is 0, it means |interface_var| has a Patch + // decoration. Otherwise, |extra_array_length| denotes the length of the extra + // array of |interface_var|. + bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var, + Instruction* interface_var_type, + uint32_t location, + uint32_t component, + uint32_t extra_array_length); + + // Creates scalar variables with the storage classe |storage_class| to replace + // an interface variable whose type is |interface_var_type|. If + // |extra_array_length| is not zero, adds the extra arrayness to the created + // scalar variables. + NestedCompositeComponents CreateScalarInterfaceVarsForReplacement( + Instruction* interface_var_type, SpvStorageClass storage_class, + uint32_t extra_array_length); + + // Creates scalar variables with the storage classe |storage_class| to replace + // the interface variable whose type is OpTypeArray |interface_var_type| with. + // If |extra_array_length| is not zero, adds the extra arrayness to all the + // scalar variables. + NestedCompositeComponents CreateScalarInterfaceVarsForArray( + Instruction* interface_var_type, SpvStorageClass storage_class, + uint32_t extra_array_length); + + // Creates scalar variables with the storage classe |storage_class| to replace + // the interface variable whose type is OpTypeMatrix |interface_var_type| + // with. If |extra_array_length| is not zero, adds the extra arrayness to all + // the scalar variables. + NestedCompositeComponents CreateScalarInterfaceVarsForMatrix( + Instruction* interface_var_type, SpvStorageClass storage_class, + uint32_t extra_array_length); + + // Recursively adds Location and Component decorations to variables in + // |vars| with |location| and |component|. Increases |location| by one after + // it actually adds Location and Component decorations for a variable. + void AddLocationAndComponentDecorations(const NestedCompositeComponents& vars, + uint32_t* location, + uint32_t component); + + // Replaces the interface variable |interface_var| with + // |scalar_interface_vars| and returns whether it succeeds or not. + // |extra_arrayness| is the extra arrayness of the interface variable. + // |scalar_interface_vars| contains the nested variables to replace the + // interface variable with. + bool ReplaceInterfaceVarWith( + Instruction* interface_var, uint32_t extra_arrayness, + const NestedCompositeComponents& scalar_interface_vars); + + // Replaces |interface_var| in the operands of instructions + // |interface_var_users| with |scalar_interface_vars|. This is a recursive + // method and |interface_var_component_indices| is used to specify which + // recursive component of |interface_var| is replaced. Returns composite + // construct instructions to be replaced with load instructions of + // |interface_var_users| via |loads_to_composites|. Returns composite + // construct instructions to be replaced with load instructions of access + // chain instructions in |interface_var_users| via + // |loads_for_access_chain_to_composites|. + bool ReplaceComponentsOfInterfaceVarWith( + Instruction* interface_var, + const std::vector& interface_var_users, + const NestedCompositeComponents& scalar_interface_vars, + std::vector& interface_var_component_indices, + const uint32_t* extra_array_index, + std::unordered_map* loads_to_composites, + std::unordered_map* + loads_for_access_chain_to_composites); + + // Replaces |interface_var| in the operands of instructions + // |interface_var_users| with |components| that is a vector of components for + // the interface variable |interface_var|. This is a recursive method and + // |interface_var_component_indices| is used to specify which recursive + // component of |interface_var| is replaced. Returns composite construct + // instructions to be replaced with load instructions of |interface_var_users| + // via |loads_to_composites|. Returns composite construct instructions to be + // replaced with load instructions of access chain instructions in + // |interface_var_users| via |loads_for_access_chain_to_composites|. + bool ReplaceMultipleComponentsOfInterfaceVarWith( + Instruction* interface_var, + const std::vector& interface_var_users, + const std::vector& components, + std::vector& interface_var_component_indices, + const uint32_t* extra_array_index, + std::unordered_map* loads_to_composites, + std::unordered_map* + loads_for_access_chain_to_composites); + + // Replaces a component of |interface_var| that is used as an operand of + // instruction |interface_var_user| with |scalar_var|. + // |interface_var_component_indices| is a vector of recursive indices for + // which recursive component of |interface_var| is replaced. If + // |interface_var_user| is a load, returns the component value via + // |loads_to_component_values|. If |interface_var_user| is an access chain, + // returns the component value for loads of |interface_var_user| via + // |loads_for_access_chain_to_component_values|. + bool ReplaceComponentOfInterfaceVarWith( + Instruction* interface_var, Instruction* interface_var_user, + Instruction* scalar_var, + const std::vector& interface_var_component_indices, + const uint32_t* extra_array_index, + std::unordered_map* loads_to_component_values, + std::unordered_map* + loads_for_access_chain_to_component_values); + + // Creates instructions to load |scalar_var| and inserts them before + // |insert_before|. If |extra_array_index| is not null, they load + // |extra_array_index| th component of |scalar_var| instead of |scalar_var| + // itself. + Instruction* LoadScalarVar(Instruction* scalar_var, + const uint32_t* extra_array_index, + Instruction* insert_before); + + // Creates instructions to load an access chain to |var| and inserts them + // before |insert_before|. |Indexes| will be Indexes operand of the access + // chain. + Instruction* LoadAccessChainToVar(Instruction* var, + const std::vector& indexes, + Instruction* insert_before); + + // Creates instructions to store a component of an aggregate whose id is + // |value_id| to an access chain to |scalar_var| and inserts the created + // instructions before |insert_before|. To get the component, recursively + // traverses the aggregate with |component_indices| as indexes. + // Numbers in |access_chain_indices| are the Indexes operand of the access + // chain to |scalar_var| + void StoreComponentOfValueToAccessChainToScalarVar( + uint32_t value_id, const std::vector& component_indices, + Instruction* scalar_var, + const std::vector& access_chain_indices, + Instruction* insert_before); + + // Creates instructions to store a component of an aggregate whose id is + // |value_id| to |scalar_var| and inserts the created instructions before + // |insert_before|. To get the component, recursively traverses the aggregate + // using |extra_array_index| and |component_indices| as indexes. + void StoreComponentOfValueToScalarVar( + uint32_t value_id, const std::vector& component_indices, + Instruction* scalar_var, const uint32_t* extra_array_index, + Instruction* insert_before); + + // Creates instructions to store a component of an aggregate whose id is + // |value_id| to |ptr| and inserts the created instructions before + // |insert_before|. To get the component, recursively traverses the aggregate + // using |extra_array_index| and |component_indices| as indexes. + // |component_type_id| is the id of the type instruction of the component. + void StoreComponentOfValueTo(uint32_t component_type_id, uint32_t value_id, + const std::vector& component_indices, + Instruction* ptr, + const uint32_t* extra_array_index, + Instruction* insert_before); + + // Creates new OpCompositeExtract with |type_id| for Result Type, + // |composite_id| for Composite operand, and |indexes| for Indexes operands. + // If |extra_first_index| is not nullptr, uses it as the first Indexes + // operand. + Instruction* CreateCompositeExtract(uint32_t type_id, uint32_t composite_id, + const std::vector& indexes, + const uint32_t* extra_first_index); + + // Creates a new OpLoad whose Result Type is |type_id| and Pointer operand is + // |ptr|. Inserts the new instruction before |insert_before|. + Instruction* CreateLoad(uint32_t type_id, Instruction* ptr, + Instruction* insert_before); + + // Clones an annotation instruction |annotation_inst| and sets the target + // operand of the new annotation instruction as |var_id|. + void CloneAnnotationForVariable(Instruction* annotation_inst, + uint32_t var_id); + + // Replaces the interface variable |interface_var| in the operands of the + // entry point |entry_point| with |scalar_var_id|. If it cannot find + // |interface_var| from the operands of the entry point |entry_point|, adds + // |scalar_var_id| as an operand of the entry point |entry_point|. + bool ReplaceInterfaceVarInEntryPoint(Instruction* interface_var, + Instruction* entry_point, + uint32_t scalar_var_id); + + // Creates an access chain instruction whose Base operand is |var| and Indexes + // operand is |index|. |component_type_id| is the id of the type instruction + // that is the type of component. Inserts the new access chain before + // |insert_before|. + Instruction* CreateAccessChainWithIndex(uint32_t component_type_id, + Instruction* var, uint32_t index, + Instruction* insert_before); + + // Returns the pointee type of the type of variable |var|. + uint32_t GetPointeeTypeIdOfVar(Instruction* var); + + // Replaces the access chain |access_chain| and its users with a new access + // chain that points |scalar_var| as the Base operand having + // |interface_var_component_indices| as Indexes operands and users of the new + // access chain. When some of the users are load instructions, returns the + // original load instruction to the new instruction that loads a component of + // the original load value via |loads_to_component_values|. + void ReplaceAccessChainWith( + Instruction* access_chain, + const std::vector& interface_var_component_indices, + Instruction* scalar_var, + std::unordered_map* + loads_to_component_values); + + // Assuming that |access_chain| is an access chain instruction whose Base + // operand is |base_access_chain|, replaces the operands of |access_chain| + // with operands of |base_access_chain| and Indexes operands of + // |access_chain|. + void UseBaseAccessChainForAccessChain(Instruction* access_chain, + Instruction* base_access_chain); + + // Creates composite construct instructions for load instructions that are the + // keys of |loads_to_component_values| if no such composite construct + // instructions exist. Adds a component of the composite as an operand of the + // created composite construct instruction. Each value of + // |loads_to_component_values| is the component. Returns the created composite + // construct instructions using |loads_to_composites|. |depth_to_component| is + // the number of recursive access steps to get the component from the + // composite. + void AddComponentsToCompositesForLoads( + const std::unordered_map& + loads_to_component_values, + std::unordered_map* loads_to_composites, + uint32_t depth_to_component); + + // Creates a composite construct instruction for a component of the value of + // instruction |load| in |depth_to_component| th recursive depth and inserts + // it after |load|. + Instruction* CreateCompositeConstructForComponentOfLoad( + Instruction* load, uint32_t depth_to_component); + + // Creates a new access chain instruction that points to variable |var| whose + // type is the instruction with |var_type_id| and inserts it before + // |insert_before|. The new access chain will have |index_ids| for Indexes + // operands. Returns the type id of the component that is pointed by the new + // access chain via |component_type_id|. + Instruction* CreateAccessChainToVar(uint32_t var_type_id, Instruction* var, + const std::vector& index_ids, + Instruction* insert_before, + uint32_t* component_type_id); + + // Returns the result id of OpTypeArray instrunction whose Element Type + // operand is |elem_type_id| and Length operand is |array_length|. + uint32_t GetArrayType(uint32_t elem_type_id, uint32_t array_length); + + // Returns the result id of OpTypePointer instrunction whose Type + // operand is |type_id| and Storage Class operand is |storage_class|. + uint32_t GetPointerType(uint32_t type_id, SpvStorageClass storage_class); + + // Kills an instrunction |inst| and its users. + void KillInstructionAndUsers(Instruction* inst); + + // Kills a vector of instrunctions |insts| and their users. + void KillInstructionsAndUsers(const std::vector& insts); + + // Kills all OpDecorate instructions for Location and Component of the + // variable whose id is |var_id|. + void KillLocationAndComponentDecorations(uint32_t var_id); + + // If |var| has the extra arrayness for an entry point, reports an error and + // returns true. Otherwise, returns false. + bool ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var); + + // If |var| does not have the extra arrayness for an entry point, reports an + // error and returns true. Otherwise, returns false. + bool ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var); + + // If |interface_var| has the extra arrayness for an entry point but it does + // not have one for another entry point, reports an error and returns false. + // Otherwise, returns true. |has_extra_arrayness| denotes whether it has an + // extra arrayness for an entry point or not. + bool CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var, + bool has_extra_arrayness); + + // Conducts the scalar replacement for the interface variables used by the + // |entry_point|. + Pass::Status ReplaceInterfaceVarsWithScalars(Instruction& entry_point); + + // A set of interface variable ids that were already removed from operands of + // the entry point. + std::unordered_set + interface_vars_removed_from_entry_point_operands_; + + // A mapping from ids of new composite construct instructions that load + // instructions are replaced with to the recursive depth of the component of + // load that the new component construct instruction is used for. + std::unordered_map composite_ids_to_component_depths; + + // A set of interface variables with the extra arrayness for any of the entry + // points. + std::unordered_set vars_with_extra_arrayness; + + // A set of interface variables without the extra arrayness for any of the + // entry points. + std::unordered_set vars_without_extra_arrayness; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_INTERFACE_VAR_SROA_H_ diff --git a/3rdparty/spirv-tools/source/opt/ir_context.h b/3rdparty/spirv-tools/source/opt/ir_context.h index f9f51532b..2f27942b4 100644 --- a/3rdparty/spirv-tools/source/opt/ir_context.h +++ b/3rdparty/spirv-tools/source/opt/ir_context.h @@ -1094,6 +1094,9 @@ void IRContext::AddDebug2Inst(std::unique_ptr&& d) { id_to_name_->insert({d->GetSingleWordInOperand(0), d.get()}); } } + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstDefUse(d.get()); + } module()->AddDebug2Inst(std::move(d)); } diff --git a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp index 0c6d0c24c..559856ec6 100644 --- a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.cpp @@ -28,8 +28,6 @@ namespace { const uint32_t kStoreValIdInIdx = 1; const uint32_t kAccessChainPtrIdInIdx = 0; -const uint32_t kConstantValueInIdx = 0; -const uint32_t kTypeIntWidthInIdx = 0; } // anonymous namespace @@ -67,7 +65,19 @@ void LocalAccessChainConvertPass::AppendConstantOperands( ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) { if (iidIdx > 0) { const Instruction* cInst = get_def_use_mgr()->GetDef(*iid); - uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx); + const auto* constant_value = + context()->get_constant_mgr()->GetConstantFromInst(cInst); + assert(constant_value != nullptr && + "Expecting the index to be a constant."); + + // We take the sign extended value because OpAccessChain interprets the + // index as signed. + int64_t long_value = constant_value->GetSignExtendedValue(); + assert(long_value <= UINT32_MAX && long_value >= 0 && + "The index value is too large for a composite insert or extract " + "instruction."); + + uint32_t val = static_cast(long_value); in_opnds->push_back( {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}}); } @@ -169,13 +179,16 @@ bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement( return true; } -bool LocalAccessChainConvertPass::IsConstantIndexAccessChain( +bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain( const Instruction* acp) const { uint32_t inIdx = 0; return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) { if (inIdx > 0) { Instruction* opInst = get_def_use_mgr()->GetDef(*tid); if (opInst->opcode() != SpvOpConstant) return false; + const auto* index = + context()->get_constant_mgr()->GetConstantFromInst(opInst); + if (index->GetSignExtendedValue() > UINT32_MAX) return false; } ++inIdx; return true; @@ -231,7 +244,7 @@ void LocalAccessChainConvertPass::FindTargetVars(Function* func) { break; } // Rule out variables accessed with non-constant indices - if (!IsConstantIndexAccessChain(ptrInst)) { + if (!Is32BitConstantIndexAccessChain(ptrInst)) { seen_non_target_vars_.insert(varId); seen_target_vars_.erase(varId); break; @@ -349,12 +362,6 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const { } Pass::Status LocalAccessChainConvertPass::ProcessImpl() { - // If non-32-bit integer type in module, terminate processing - // TODO(): Handle non-32-bit integer constants in access chains - for (const Instruction& inst : get_module()->types_values()) - if (inst.opcode() == SpvOpTypeInt && - inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) - return Status::SuccessWithoutChange; // Do not process if module contains OpGroupDecorate. Additional // support required in KillNamesAndDecorates(). // TODO(greg-lunarg): Add support for OpGroupDecorate diff --git a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h index a51660f10..8548e164a 100644 --- a/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h +++ b/3rdparty/spirv-tools/source/opt/local_access_chain_convert_pass.h @@ -95,7 +95,8 @@ class LocalAccessChainConvertPass : public MemPass { Instruction* original_load); // Return true if all indices of access chain |acp| are OpConstant integers - bool IsConstantIndexAccessChain(const Instruction* acp) const; + // whose values can fit into an unsigned 32-bit value. + bool Is32BitConstantIndexAccessChain(const Instruction* acp) const; // Identify all function scope variables of target type which are // accessed only with loads, stores and access chains with constant diff --git a/3rdparty/spirv-tools/source/opt/optimizer.cpp b/3rdparty/spirv-tools/source/opt/optimizer.cpp index 051d573d8..29761518f 100644 --- a/3rdparty/spirv-tools/source/opt/optimizer.cpp +++ b/3rdparty/spirv-tools/source/opt/optimizer.cpp @@ -1020,6 +1020,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass( MakeUnique(descriptor_set_binding_pairs)); } +Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() { + return MakeUnique( + MakeUnique()); +} + Optimizer::PassToken CreateRemoveDontInlinePass() { return MakeUnique( MakeUnique()); diff --git a/3rdparty/spirv-tools/source/opt/passes.h b/3rdparty/spirv-tools/source/opt/passes.h index facaa410e..21354c77b 100644 --- a/3rdparty/spirv-tools/source/opt/passes.h +++ b/3rdparty/spirv-tools/source/opt/passes.h @@ -49,6 +49,7 @@ #include "source/opt/inst_bindless_check_pass.h" #include "source/opt/inst_buff_addr_check_pass.h" #include "source/opt/inst_debug_printf_pass.h" +#include "source/opt/interface_var_sroa.h" #include "source/opt/interp_fixup_pass.h" #include "source/opt/licm_pass.h" #include "source/opt/local_access_chain_convert_pass.h" diff --git a/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.cpp b/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.cpp index 4cadf600e..e97593ef3 100644 --- a/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.cpp +++ b/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.cpp @@ -95,7 +95,7 @@ void ReplaceDescArrayAccessUsingVarIndex::ReplaceUsersOfAccessChain( CollectRecursiveUsersWithConcreteType(access_chain, &final_users); for (auto* inst : final_users) { std::deque insts_to_be_cloned = - CollectRequiredImageInsts(inst); + CollectRequiredImageAndAccessInsts(inst); ReplaceNonUniformAccessWithSwitchCase( inst, access_chain, number_of_elements, insts_to_be_cloned); } @@ -121,8 +121,8 @@ void ReplaceDescArrayAccessUsingVarIndex::CollectRecursiveUsersWithConcreteType( } std::deque -ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageInsts( - Instruction* user_of_image_insts) const { +ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageAndAccessInsts( + Instruction* user) const { std::unordered_set seen_inst_ids; std::queue work_list; @@ -131,21 +131,23 @@ ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageInsts( if (!seen_inst_ids.insert(*idp).second) return; Instruction* operand = get_def_use_mgr()->GetDef(*idp); if (context()->get_instr_block(operand) != nullptr && - HasImageOrImagePtrType(operand)) { + (HasImageOrImagePtrType(operand) || + operand->opcode() == SpvOpAccessChain || + operand->opcode() == SpvOpInBoundsAccessChain)) { work_list.push(operand); } }; - std::deque required_image_insts; - required_image_insts.push_front(user_of_image_insts); - user_of_image_insts->ForEachInId(decision_to_include_operand); + std::deque required_insts; + required_insts.push_front(user); + user->ForEachInId(decision_to_include_operand); while (!work_list.empty()) { auto* inst_from_work_list = work_list.front(); work_list.pop(); - required_image_insts.push_front(inst_from_work_list); + required_insts.push_front(inst_from_work_list); inst_from_work_list->ForEachInId(decision_to_include_operand); } - return required_image_insts; + return required_insts; } bool ReplaceDescArrayAccessUsingVarIndex::HasImageOrImagePtrType( diff --git a/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.h b/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.h index 0c97f7eb2..51817c15f 100644 --- a/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.h +++ b/3rdparty/spirv-tools/source/opt/replace_desc_array_access_using_var_index.h @@ -76,11 +76,12 @@ class ReplaceDescArrayAccessUsingVarIndex : public Pass { void CollectRecursiveUsersWithConcreteType( Instruction* access_chain, std::vector* final_users) const; - // Recursively collects the operands of |user_of_image_insts| (and operands - // of the operands) whose result types are images/samplers or pointers/array/ - // struct of them and returns them. - std::deque CollectRequiredImageInsts( - Instruction* user_of_image_insts) const; + // Recursively collects the operands of |user| (and operands of the operands) + // whose result types are images/samplers (or pointers/arrays/ structs of + // them) and access chains instructions and returns them. The returned + // collection includes |user|. + std::deque CollectRequiredImageAndAccessInsts( + Instruction* user) const; // Returns whether result type of |inst| is an image/sampler/pointer of image // or sampler or not.