Updated spirv-tools.

This commit is contained in:
Бранимир Караџић
2024-12-28 22:40:57 -08:00
parent af34836458
commit b0ef2b8c4b
51 changed files with 1466 additions and 138 deletions

View File

@@ -49,7 +49,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@@ -79,7 +79,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@@ -111,7 +111,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@@ -142,7 +142,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@@ -177,7 +177,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@@ -213,7 +213,7 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
if (_.IsCooperativeMatrixType(result_type) ||
_.IsCooperativeMatrixType(input_type)) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
_.CooperativeMatrixShapesMatch(inst, result_type, input_type, true);
if (ret != SPV_SUCCESS) return ret;
} else {
if (_.GetDimension(result_type) != _.GetDimension(input_type))
@@ -497,8 +497,8 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
<< "matrix: " << spvOpcodeString(opcode);
if (result_is_coopmat) {
spv_result_t ret =
_.CooperativeMatrixShapesMatch(inst, result_type, input_type);
spv_result_t ret = _.CooperativeMatrixShapesMatch(inst, result_type,
input_type, false);
if (ret != SPV_SUCCESS) return ret;
}
@@ -568,6 +568,43 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
break;
}
case spv::Op::OpCooperativeMatrixConvertNV:
case spv::Op::OpCooperativeMatrixTransposeNV: {
if (!_.IsCooperativeMatrixType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix Result Type: "
<< spvOpcodeString(opcode);
}
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!_.IsCooperativeMatrixType(input_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type for Matrix input: "
<< spvOpcodeString(opcode);
}
bool swap_row_col = (opcode == spv::Op::OpCooperativeMatrixTransposeNV);
if (auto error = _.CooperativeMatrixShapesMatch(
inst, result_type, input_type, true, swap_row_col))
return error;
if (opcode == spv::Op::OpCooperativeMatrixConvertNV) {
if (_.FindDef(result_type)->GetOperandAs<uint32_t>(1) !=
_.FindDef(input_type)->GetOperandAs<uint32_t>(1)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type and Matrix component types mismatch: "
<< spvOpcodeString(opcode);
}
}
if (opcode == spv::Op::OpCooperativeMatrixTransposeNV) {
if (!_.IsCooperativeMatrixBType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result Type must have UseB: " << spvOpcodeString(opcode);
}
}
break;
}
default:
break;
}