// Copyright (c) 2023-2025 Arm Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Validates correctness of graph instructions. #include #include "source/opcode.h" #include "source/val/validate.h" #include "source/val/validation_state.h" namespace spvtools { namespace val { namespace { bool IsTensorArray(ValidationState_t& _, uint32_t id) { auto def = _.FindDef(id); if (!def || (def->opcode() != spv::Op::OpTypeArray && def->opcode() != spv::Op::OpTypeRuntimeArray)) { return false; } auto tdef = _.FindDef(def->word(2)); if (!tdef || tdef->opcode() != spv::Op::OpTypeTensorARM) { return false; } return true; } bool IsGraphInterfaceType(ValidationState_t& _, uint32_t id) { return _.IsTensorType(id) || IsTensorArray(_, id); } bool IsGraph(ValidationState_t& _, uint32_t id) { auto def = _.FindDef(id); if (!def || def->opcode() != spv::Op::OpGraphARM) { return false; } return true; } bool IsGraphType(ValidationState_t& _, uint32_t id) { auto def = _.FindDef(id); if (!def || def->opcode() != spv::Op::OpTypeGraphARM) { return false; } return true; } const uint32_t kGraphTypeIOStartWord = 3; uint32_t GraphTypeInstNumIO(const Instruction* inst) { return static_cast(inst->words().size()) - kGraphTypeIOStartWord; } uint32_t GraphTypeInstNumInputs(const Instruction* inst) { return inst->word(2); } uint32_t GraphTypeInstNumOutputs(const Instruction* inst) { return GraphTypeInstNumIO(inst) - GraphTypeInstNumInputs(inst); } uint32_t GraphTypeInstGetOutputAtIndex(const Instruction* 
inst, uint64_t index) { return inst->word(kGraphTypeIOStartWord + GraphTypeInstNumInputs(inst) + static_cast(index)); } uint32_t GraphTypeInstGetInputAtIndex(const Instruction* inst, uint64_t index) { return inst->word(kGraphTypeIOStartWord + static_cast(index)); } spv_result_t ValidateGraphType(ValidationState_t& _, const Instruction* inst) { // Check there are at least NumInputs types uint32_t NumInputs = GraphTypeInstNumInputs(inst); size_t NumIOTypes = GraphTypeInstNumIO(inst); if (NumIOTypes < NumInputs) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << NumIOTypes << " I/O types were provided but the graph has " << NumInputs << " inputs."; } // Check there is at least one output if (NumIOTypes == NumInputs) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "A graph type must have at least one output."; } // Check all I/O types are graph interface type for (unsigned i = kGraphTypeIOStartWord; i < inst->words().size(); i++) { auto tid = inst->word(i); if (!IsGraphInterfaceType(_, tid)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "I/O type " << _.getIdName(tid) << " is not a Graph Interface Type."; } } return SPV_SUCCESS; } spv_result_t ValidateGraphConstant(ValidationState_t& _, const Instruction* inst) { // Check Result Type if (!_.IsTensorType(inst->type_id())) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " must have a Result Type that is a tensor type."; } // Check the instruction is not preceded by another OpGraphConstantARM with // the same ID const uint32_t cst_id = inst->word(3); size_t inst_num = inst->LineNum() - 1; while (--inst_num) { auto prev_inst = &_.ordered_instructions()[inst_num]; if (prev_inst->opcode() == spv::Op::OpGraphConstantARM) { const uint32_t prev_cst_id = prev_inst->word(3); if (prev_cst_id == cst_id) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "No two OpGraphConstantARM instructions may have the same " "GraphConstantID"; } } } return SPV_SUCCESS; } spv_result_t 
ValidateGraphEntryPoint(ValidationState_t& _, const Instruction* inst) { // Graph must be an OpGraphARM uint32_t graph = inst->GetOperandAs(0); auto graph_inst = _.FindDef(graph); if (!IsGraph(_, graph)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " Graph must be a OpGraphARM but found " << spvOpcodeString(graph_inst->opcode()) << "."; } // Check number of Interface IDs matches number of I/Os of graph auto graph_type_inst = _.FindDef(graph_inst->type_id()); size_t graph_type_num_io = GraphTypeInstNumIO(graph_type_inst); size_t graph_entry_point_num_interface_id = inst->operands().size() - 2; if (graph_type_inst->opcode() != spv::Op::OpTypeGraphARM) { // This is invalid but we want ValidateGraph to report a clear error // so stop validating the graph entry point instruction return SPV_SUCCESS; } if (graph_type_num_io != graph_entry_point_num_interface_id) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " Interface list contains " << graph_entry_point_num_interface_id << " IDs but Graph's type " << _.getIdName(graph_inst->type_id()) << " has " << graph_type_num_io << " inputs and outputs."; } // Check Interface IDs for (uint32_t i = 2; i < inst->operands().size(); i++) { uint32_t interface_id = inst->GetOperandAs(i); auto interface_inst = _.FindDef(interface_id); // Check interface IDs come from OpVariable if ((interface_inst->opcode() != spv::Op::OpVariable) || (interface_inst->GetOperandAs(2) != spv::StorageClass::UniformConstant)) { return _.diag(SPV_ERROR_INVALID_DATA, interface_inst) << spvOpcodeString(inst->opcode()) << " Interface ID " << _.getIdName(interface_id) << " must come from OpVariable with UniformConstant Storage " "Class."; } // Check type of interface variable matches type of the corresponding graph // I/O uint32_t corresponding_graph_io_type = graph_type_inst->GetOperandAs(i); uint32_t interface_ptr_type = interface_inst->type_id(); auto interface_ptr_inst = 
_.FindDef(interface_ptr_type); auto interface_pointee_type = interface_ptr_inst->GetOperandAs(2); if (interface_pointee_type != corresponding_graph_io_type) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " Interface ID type " << _.getIdName(interface_pointee_type) << " must match the type of the corresponding graph I/O " << _.getIdName(corresponding_graph_io_type); } } return SPV_SUCCESS; } spv_result_t ValidateGraph(ValidationState_t& _, const Instruction* inst) { // Result Type must be an OpTypeGraphARM if (!IsGraphType(_, inst->type_id())) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " Result Type must be an OpTypeGraphARM."; } return SPV_SUCCESS; } spv_result_t ValidateGraphInput(ValidationState_t& _, const Instruction* inst) { // Check type of InputIndex auto input_index_inst = _.FindDef(inst->GetOperandAs(2)); if (!input_index_inst || !_.IsIntScalarType(input_index_inst->type_id(), 32)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " InputIndex must be a 32-bit integer."; } bool has_element_index = inst->operands().size() > 3; // Check type of ElementIndex if (has_element_index) { auto element_index_inst = _.FindDef(inst->GetOperandAs(3)); if (!element_index_inst || !_.IsIntScalarType(element_index_inst->type_id(), 32)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " ElementIndex must be a 32-bit integer."; } } // Find graph definition size_t inst_num = inst->LineNum() - 1; auto graph_inst = &_.ordered_instructions()[inst_num]; while (--inst_num) { graph_inst = &_.ordered_instructions()[inst_num]; if (graph_inst->opcode() == spv::Op::OpGraphARM) { break; } } // Can the InputIndex be evaluated? // If not, there's nothing more we can validate here. 
uint64_t input_index; if (!_.EvalConstantValUint64(inst->GetOperandAs(2), &input_index)) { return SPV_SUCCESS; } auto const graph_type_inst = _.FindDef(graph_inst->type_id()); size_t graph_type_num_inputs = graph_type_inst->GetOperandAs(1); // Check InputIndex is in range if (input_index >= graph_type_num_inputs) { std::string disassembly = _.Disassemble(*inst); return _.diag(SPV_ERROR_INVALID_DATA, nullptr) << "Type " << _.getIdName(graph_type_inst->id()) << " for graph " << _.getIdName(graph_inst->id()) << " has " << graph_type_num_inputs << " inputs but found an OpGraphInputARM instruction with an " "InputIndex that is " << input_index << ": " << disassembly; } uint32_t graph_type_input_type = GraphTypeInstGetInputAtIndex(graph_type_inst, input_index); if (has_element_index) { // Check ElementIndex is allowed if (!IsTensorArray(_, graph_type_input_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "OpGraphInputARM ElementIndex not allowed when the graph input " "selected by " << "InputIndex is not an OpTypeArray or OpTypeRuntimeArray"; } // Check ElementIndex is in range if it can be evaluated and the input is a // fixed-sized array whose Length can be evaluated uint64_t element_index; if (_.IsArrayType(graph_type_input_type) && _.EvalConstantValUint64(inst->GetOperandAs(3), &element_index)) { uint64_t array_length; auto graph_type_input_type_inst = _.FindDef(graph_type_input_type); if (_.EvalConstantValUint64( graph_type_input_type_inst->GetOperandAs(2), &array_length)) { if (element_index >= array_length) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "OpGraphInputARM ElementIndex out of range. 
The type of " "the graph input being accessed " << _.getIdName(graph_type_input_type) << " is an array of " << array_length << " elements but " << "ElementIndex is " << element_index; } } } } // Check result type matches with graph type if (has_element_index) { uint32_t expected_type = _.GetComponentType(graph_type_input_type); if (inst->type_id() != expected_type) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Result Type " << _.getIdName(inst->type_id()) << " of graph input instruction " << _.getIdName(inst->id()) << " does not match the component type " << _.getIdName(expected_type) << " of input " << input_index << " in the graph type."; } } else { if (inst->type_id() != graph_type_input_type) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Result Type " << _.getIdName(inst->type_id()) << " of graph input instruction " << _.getIdName(inst->id()) << " does not match the type " << _.getIdName(graph_type_input_type) << " of input " << input_index << " in the graph type."; } } return SPV_SUCCESS; } spv_result_t ValidateGraphSetOutput(ValidationState_t& _, const Instruction* inst) { // Check type of OutputIndex auto output_index_inst = _.FindDef(inst->GetOperandAs(1)); if (!output_index_inst || !_.IsIntScalarType(output_index_inst->type_id(), 32)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " OutputIndex must be a 32-bit integer."; } bool has_element_index = inst->operands().size() > 2; // Check type of ElementIndex if (has_element_index) { auto element_index_inst = _.FindDef(inst->GetOperandAs(2)); if (!element_index_inst || !_.IsIntScalarType(element_index_inst->type_id(), 32)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " ElementIndex must be a 32-bit integer."; } } // Find graph definition size_t inst_num = inst->LineNum() - 1; auto graph_inst = &_.ordered_instructions()[inst_num]; while (--inst_num) { graph_inst = &_.ordered_instructions()[inst_num]; if (graph_inst->opcode() == 
spv::Op::OpGraphARM) { break; } } // Can the OutputIndex be evaluated? // If not, there's nothing more we can validate here. uint64_t output_index; if (!_.EvalConstantValUint64(inst->GetOperandAs(1), &output_index)) { return SPV_SUCCESS; } // Check that the OutputIndex is valid with respect to the graph type auto graph_type_inst = _.FindDef(graph_inst->type_id()); size_t graph_type_num_outputs = GraphTypeInstNumOutputs(graph_type_inst); if (output_index >= graph_type_num_outputs) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(inst->opcode()) << " setting OutputIndex " << output_index << " but graph only has " << graph_type_num_outputs << " outputs."; } uint32_t graph_type_output_type = GraphTypeInstGetOutputAtIndex(graph_type_inst, output_index); if (has_element_index) { // Check ElementIndex is allowed if (!IsTensorArray(_, graph_type_output_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "OpGraphSetOutputARM ElementIndex not allowed when the graph " "output selected by " << "OutputIndex is not an OpTypeArray or OpTypeRuntimeArray"; } // Check ElementIndex is in range if it can be evaluated and the output is a // fixed-sized array whose Length can be evaluated uint64_t element_index; if (_.IsArrayType(graph_type_output_type) && _.EvalConstantValUint64(inst->GetOperandAs(2), &element_index)) { uint64_t array_length; auto graph_type_output_type_inst = _.FindDef(graph_type_output_type); if (_.EvalConstantValUint64( graph_type_output_type_inst->GetOperandAs(2), &array_length)) { if (element_index >= array_length) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "OpGraphSetOutputARM ElementIndex out of range. 
The type " "of the graph output being accessed " << _.getIdName(graph_type_output_type) << " is an array of " << array_length << " elements but " << "ElementIndex is " << element_index; } } } } // Check Value's type matches with graph type uint32_t value = inst->GetOperandAs(0); uint32_t value_type = _.FindDef(value)->type_id(); if (has_element_index) { uint32_t expected_type = _.GetComponentType(graph_type_output_type); if (value_type != expected_type) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "The type " << _.getIdName(value_type) << " of Value provided to the graph output instruction " << _.getIdName(value) << " does not match the component type " << _.getIdName(expected_type) << " of output " << output_index << " in the graph type."; } } else { if (value_type != graph_type_output_type) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "The type " << _.getIdName(value_type) << " of Value provided to the graph output instruction " << _.getIdName(value) << " does not match the type " << _.getIdName(graph_type_output_type) << " of output " << output_index << " in the graph type."; } } return SPV_SUCCESS; } bool InputOutputInstructionsHaveDuplicateIndices( ValidationState_t& _, std::deque& inout_insts, const Instruction** first_dup) { std::set> inout_element_indices; for (auto const inst : inout_insts) { const bool is_input = inst->opcode() == spv::Op::OpGraphInputARM; bool has_element_index = inst->operands().size() > (is_input ? 3 : 2); uint64_t inout_index; if (!_.EvalConstantValUint64(inst->GetOperandAs(is_input ? 2 : 1), &inout_index)) { continue; } uint64_t element_index = -1; // -1 means no ElementIndex if (has_element_index) { if (!_.EvalConstantValUint64( inst->GetOperandAs(is_input ? 
3 : 2), &element_index)) { continue; } } auto inout_element_pair = std::make_pair(inout_index, element_index); auto inout_noelement_pair = std::make_pair(inout_index, -1); if (inout_element_indices.count(inout_element_pair) || inout_element_indices.count(inout_noelement_pair)) { *first_dup = inst; return true; } inout_element_indices.insert(inout_element_pair); } return false; } spv_result_t ValidateGraphEnd(ValidationState_t& _, const Instruction* inst) { size_t end_inst_num = inst->LineNum() - 1; // Gather OpGraphInputARM and OpGraphSetOutputARM instructions std::deque graph_inputs, graph_outputs; size_t in_inst_num = end_inst_num; auto graph_inst = &_.ordered_instructions()[in_inst_num]; while (--in_inst_num) { graph_inst = &_.ordered_instructions()[in_inst_num]; if (graph_inst->opcode() == spv::Op::OpGraphInputARM) { graph_inputs.push_front(graph_inst); continue; } if (graph_inst->opcode() == spv::Op::OpGraphSetOutputARM) { graph_outputs.push_front(graph_inst); continue; } if (graph_inst->opcode() == spv::Op::OpGraphARM) { break; } } const Instruction* first_dup; // Check that there are no duplicate InputIndex and ElementIndex values if (InputOutputInstructionsHaveDuplicateIndices(_, graph_inputs, &first_dup)) { return _.diag(SPV_ERROR_INVALID_DATA, first_dup) << "Two OpGraphInputARM instructions with the same InputIndex " "must not be part of the same " << "graph definition unless ElementIndex is present in both with " "different values."; } // Check that there are no duplicate OutputIndex and ElementIndex values if (InputOutputInstructionsHaveDuplicateIndices(_, graph_outputs, &first_dup)) { return _.diag(SPV_ERROR_INVALID_DATA, first_dup) << "Two OpGraphSetOutputARM instructions with the same " "OutputIndex must not be part of the same " << "graph definition unless ElementIndex is present in both with " "different values."; } return SPV_SUCCESS; } } // namespace // Validates correctness of graph instructions. 
// Entry point of the graph validation pass: routes |inst| to the validation
// routine for its opcode. Any opcode not listed here yields SPV_SUCCESS.
spv_result_t GraphPass(ValidationState_t& _, const Instruction* inst) {
  const spv::Op opcode = inst->opcode();
  switch (opcode) {
    case spv::Op::OpTypeGraphARM: {
      return ValidateGraphType(_, inst);
    }
    case spv::Op::OpGraphConstantARM: {
      return ValidateGraphConstant(_, inst);
    }
    case spv::Op::OpGraphEntryPointARM: {
      return ValidateGraphEntryPoint(_, inst);
    }
    case spv::Op::OpGraphARM: {
      return ValidateGraph(_, inst);
    }
    case spv::Op::OpGraphInputARM: {
      return ValidateGraphInput(_, inst);
    }
    case spv::Op::OpGraphSetOutputARM: {
      return ValidateGraphSetOutput(_, inst);
    }
    case spv::Op::OpGraphEndARM: {
      return ValidateGraphEnd(_, inst);
    }
    default: {
      // Not a graph instruction: nothing for this pass to check.
      return SPV_SUCCESS;
    }
  }
}

}  // namespace val
}  // namespace spvtools