From 631edf5e8dbf39ba904ee0e81200e4cd240eae2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D1=80=D0=B0=D0=BD=D0=B8=D0=BC=D0=B8=D1=80=20=D0=9A?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D1=9F=D0=B8=D1=9B?= Date: Sat, 23 Nov 2019 08:55:57 -0800 Subject: [PATCH] Updated spirv-tools. --- .../include/generated/build-version.inc | 2 +- .../spirv-tools/source/opt/ir_context.cpp | 9 + .../source/opt/value_number_table.cpp | 2 +- 3rdparty/spirv-tools/source/opt/wrap_opkill.h | 3 +- .../test/fuzzers/spvtools_as_fuzzer.cpp | 4 +- .../spirv-tools/test/opt/value_table_test.cpp | 31 ++++ .../spirv-tools/test/tools/CMakeLists.txt | 9 +- 3rdparty/spirv-tools/test/tools/expect.py | 16 +- .../spirv-tools/test/tools/expect_nosetest.py | 80 --------- .../spirv-tools/test/tools/expect_unittest.py | 82 +++++++++ .../test/tools/spirv_test_framework.py | 31 +++- .../tools/spirv_test_framework_nosetest.py | 155 ----------------- .../tools/spirv_test_framework_unittest.py | 158 ++++++++++++++++++ 13 files changed, 326 insertions(+), 256 deletions(-) delete mode 100755 3rdparty/spirv-tools/test/tools/expect_nosetest.py create mode 100644 3rdparty/spirv-tools/test/tools/expect_unittest.py delete mode 100755 3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py create mode 100644 3rdparty/spirv-tools/test/tools/spirv_test_framework_unittest.py diff --git a/3rdparty/spirv-tools/include/generated/build-version.inc b/3rdparty/spirv-tools/include/generated/build-version.inc index b1f14bd43..578d204ee 100644 --- a/3rdparty/spirv-tools/include/generated/build-version.inc +++ b/3rdparty/spirv-tools/include/generated/build-version.inc @@ -1 +1 @@ -"v2019.5-dev", "SPIRV-Tools v2019.5-dev v2019.4-172-gc3f22f7c" +"v2019.5-dev", "SPIRV-Tools v2019.5-dev v2019.4-178-g85f3e93d" diff --git a/3rdparty/spirv-tools/source/opt/ir_context.cpp b/3rdparty/spirv-tools/source/opt/ir_context.cpp index d940180da..7bca29b20 100644 --- a/3rdparty/spirv-tools/source/opt/ir_context.cpp +++ b/3rdparty/spirv-tools/source/opt/ir_context.cpp @@ -273,6 +273,14 @@ bool IRContext::IsConsistent() { } } + if (AreAnalysesValid(kAnalysisIdToFuncMapping)) { + for (auto& fn : *module_) { + if (id_to_func_[fn.result_id()] != &fn) { + return false; + } + } + } + if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) { for (auto& func : *module()) { for (auto& block : func) { @@ -818,6 +826,7 @@ bool IRContext::ProcessCallTreeFromRoots(ProcessFunction& pfn, roots->pop(); if (done.insert(fi).second) { Function* fn = GetFunction(fi); + assert(fn && "Trying to process a function that does not exist."); modified = pfn(fn) || modified; AddCalls(fn, roots); } diff --git a/3rdparty/spirv-tools/source/opt/value_number_table.cpp b/3rdparty/spirv-tools/source/opt/value_number_table.cpp index 8df34ef5a..82549a6dc 100644 --- a/3rdparty/spirv-tools/source/opt/value_number_table.cpp +++ b/3rdparty/spirv-tools/source/opt/value_number_table.cpp @@ -93,7 +93,7 @@ uint32_t ValueNumberTable::AssignValueNumber(Instruction* inst) { // Phi nodes are a type of copy. If all of the inputs have the same value // number, then we can assign the result of the phi the same value number. - if (inst->opcode() == SpvOpPhi && + if (inst->opcode() == SpvOpPhi && inst->NumInOperands() > 0 && dec_mgr->HaveTheSameDecorations(inst->result_id(), inst->GetSingleWordInOperand(0))) { value = GetValueNumber(inst->GetSingleWordInOperand(0)); diff --git a/3rdparty/spirv-tools/source/opt/wrap_opkill.h b/3rdparty/spirv-tools/source/opt/wrap_opkill.h index 87a5d692c..09f2dfafd 100644 --- a/3rdparty/spirv-tools/source/opt/wrap_opkill.h +++ b/3rdparty/spirv-tools/source/opt/wrap_opkill.h @@ -34,8 +34,7 @@ class WrapOpKill : public Pass { IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | IRContext::kAnalysisNameMap | IRContext::kAnalysisBuiltinVarId | - IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisConstants | - IRContext::kAnalysisTypes; + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; } private: diff --git a/3rdparty/spirv-tools/test/fuzzers/spvtools_as_fuzzer.cpp b/3rdparty/spirv-tools/test/fuzzers/spvtools_as_fuzzer.cpp index 1b1de0082..8cecb05f5 100644 --- a/3rdparty/spirv-tools/test/fuzzers/spvtools_as_fuzzer.cpp +++ b/3rdparty/spirv-tools/test/fuzzers/spvtools_as_fuzzer.cpp @@ -36,7 +36,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { input_str.resize(char_count); memcpy(input_str.data(), input.data(), input.size() * sizeof(uint32_t)); - spv_binary binary; + spv_binary binary = nullptr; spv_diagnostic diagnostic = nullptr; spvTextToBinaryWithOptions(context, input_str.data(), input_str.size(), SPV_TEXT_TO_BINARY_OPTION_NONE, &binary, @@ -66,5 +66,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { binary = nullptr; } + spvContextDestroy(context); + return 0; } diff --git a/3rdparty/spirv-tools/test/opt/value_table_test.cpp b/3rdparty/spirv-tools/test/opt/value_table_test.cpp index 0b7530c08..a0942ccdc 100644 --- a/3rdparty/spirv-tools/test/opt/value_table_test.cpp +++ b/3rdparty/spirv-tools/test/opt/value_table_test.cpp @@ -653,6 +653,37 @@ TEST_F(ValueTableTest, PhiLoopTest) { EXPECT_NE(vtable.GetValueNumber(phi1), vtable.GetValueNumber(phi2)); } +// Test to make sure that OpPhi instructions with no in operands are handled +// correctly. +TEST_F(ValueTableTest, EmptyPhiTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %2 = OpFunction %void None %4 + %7 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %true %9 %8 + %9 = OpLabel + OpKill + %8 = OpLabel + %10 = OpPhi %bool + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(10); + vtable.GetValueNumber(inst); +} + } // namespace } // namespace opt } // namespace spvtools diff --git a/3rdparty/spirv-tools/test/tools/CMakeLists.txt b/3rdparty/spirv-tools/test/tools/CMakeLists.txt index cee95cadb..99f9780c5 100644 --- a/3rdparty/spirv-tools/test/tools/CMakeLists.txt +++ b/3rdparty/spirv-tools/test/tools/CMakeLists.txt @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -spirv_add_nosetests(expect) -spirv_add_nosetests(spirv_test_framework) - +add_test(NAME spirv-tools_expect_unittests + COMMAND ${PYTHON_EXECUTABLE} -m unittest expect_unittest.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +add_test(NAME spirv-tools_spirv_test_framework_unittests + COMMAND ${PYTHON_EXECUTABLE} -m unittest spirv_test_framework_unittest.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(opt) diff --git a/3rdparty/spirv-tools/test/tools/expect.py b/3rdparty/spirv-tools/test/tools/expect.py index 52999ce54..0b51adc9c 100755 --- a/3rdparty/spirv-tools/test/tools/expect.py +++ b/3rdparty/spirv-tools/test/tools/expect.py @@ -510,7 +510,7 @@ class ErrorMessageSubstr(SpirvTest): if not status.stderr: return False, 'Expected error message, but no output on stderr' if self.expected_error_substr not in convert_to_unix_line_endings( - status.stderr.decode('utf8')): + status.stderr): return False, ('Incorrect stderr output:\n{act}\n' 'Expected substring not found in stderr:\n{exp}'.format( act=status.stderr, exp=self.expected_error_substr)) @@ -530,7 +530,7 @@ class WarningMessage(SpirvTest): ' command execution') if not status.stderr: return False, 'Expected warning message, but no output on stderr' - if self.expected_warning != convert_to_unix_line_endings(status.stderr.decode('utf8')): + if self.expected_warning != convert_to_unix_line_endings(status.stderr): return False, ('Incorrect stderr output:\n{act}\n' 'Expected:\n{exp}'.format( act=status.stderr, exp=self.expected_warning)) @@ -590,16 +590,16 @@ class StdoutMatch(SpirvTest): if not status.stdout: return False, 'Expected something on stdout' elif type(self.expected_stdout) == str: - if self.expected_stdout != convert_to_unix_line_endings(status.stdout.decode('utf8')): + if self.expected_stdout != convert_to_unix_line_endings(status.stdout): return False, ('Incorrect stdout output:\n{ac}\n' 'Expected:\n{ex}'.format( ac=status.stdout, ex=self.expected_stdout)) else: - converted = convert_to_unix_line_endings(status.stdout.decode('utf8')) + converted = convert_to_unix_line_endings(status.stdout) if not self.expected_stdout.search(converted): return False, ('Incorrect stdout output:\n{ac}\n' 'Expected to match regex:\n{ex}'.format( - ac=status.stdout.decode('utf8'), ex=self.expected_stdout.pattern)) + ac=status.stdout, ex=self.expected_stdout.pattern)) return True, '' @@ -624,13 +624,13 @@ class StderrMatch(SpirvTest): if not status.stderr: return False, 'Expected something on stderr' elif type(self.expected_stderr) == str: - if self.expected_stderr != convert_to_unix_line_endings(status.stderr.decode('utf8')): + if self.expected_stderr != convert_to_unix_line_endings(status.stderr): return False, ('Incorrect stderr output:\n{ac}\n' 'Expected:\n{ex}'.format( ac=status.stderr, ex=self.expected_stderr)) else: if not self.expected_stderr.search( - convert_to_unix_line_endings(status.stderr.decode('utf8'))): + convert_to_unix_line_endings(status.stderr)): return False, ('Incorrect stderr output:\n{ac}\n' 'Expected to match regex:\n{ex}'.format( ac=status.stderr, ex=self.expected_stderr.pattern)) @@ -695,7 +695,7 @@ class ExecutedListOfPasses(SpirvTest): # Collect all the output lines containing a pass name. pass_names = [] pass_name_re = re.compile(r'.*IR before pass (?P[\S]+)') - for line in status.stderr.decode('utf8').splitlines(): + for line in status.stderr.splitlines(): match = pass_name_re.match(line) if match: pass_names.append(match.group('pass_name')) diff --git a/3rdparty/spirv-tools/test/tools/expect_nosetest.py b/3rdparty/spirv-tools/test/tools/expect_nosetest.py deleted file mode 100755 index b591a2d07..000000000 --- a/3rdparty/spirv-tools/test/tools/expect_nosetest.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for the expect module.""" - -import expect -from spirv_test_framework import TestStatus -from nose.tools import assert_equal, assert_true, assert_false -import re - - -def nosetest_get_object_name(): - """Tests get_object_filename().""" - source_and_object_names = [('a.vert', 'a.vert.spv'), ('b.frag', 'b.frag.spv'), - ('c.tesc', 'c.tesc.spv'), ('d.tese', 'd.tese.spv'), - ('e.geom', 'e.geom.spv'), ('f.comp', 'f.comp.spv'), - ('file', 'file.spv'), ('file.', 'file.spv'), - ('file.uk', - 'file.spv'), ('file.vert.', - 'file.vert.spv'), ('file.vert.bla', - 'file.vert.spv')] - actual_object_names = [ - expect.get_object_filename(f[0]) for f in source_and_object_names - ] - expected_object_names = [f[1] for f in source_and_object_names] - - assert_equal(actual_object_names, expected_object_names) - - -class TestStdoutMatchADotC(expect.StdoutMatch): - expected_stdout = re.compile('a.c') - - -def nosetest_stdout_match_regex_has_match(): - test = TestStdoutMatchADotC() - status = TestStatus( - test_manager=None, - returncode=0, - stdout='0abc1', - stderr=None, - directory=None, - inputs=None, - input_filenames=None) - assert_true(test.check_stdout_match(status)[0]) - - -def nosetest_stdout_match_regex_no_match(): - test = TestStdoutMatchADotC() - status = TestStatus( - test_manager=None, - returncode=0, - stdout='ab', - stderr=None, - directory=None, - inputs=None, - input_filenames=None) - assert_false(test.check_stdout_match(status)[0]) - - -def nosetest_stdout_match_regex_empty_stdout(): - test = TestStdoutMatchADotC() - status = TestStatus( - test_manager=None, - returncode=0, - stdout='', - stderr=None, - directory=None, - inputs=None, - input_filenames=None) - assert_false(test.check_stdout_match(status)[0]) diff --git a/3rdparty/spirv-tools/test/tools/expect_unittest.py b/3rdparty/spirv-tools/test/tools/expect_unittest.py new file mode 100644 index 000000000..a28de1b97 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/expect_unittest.py @@ -0,0 +1,82 @@ +# Copyright (c) 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the expect module.""" + +import expect +from spirv_test_framework import TestStatus +import re +import unittest + + +class TestStdoutMatchADotC(expect.StdoutMatch): + expected_stdout = re.compile('a.c') + + +class TestExpect(unittest.TestCase): + def test_get_object_name(self): + """Tests get_object_filename().""" + source_and_object_names = [('a.vert', 'a.vert.spv'), + ('b.frag', 'b.frag.spv'), + ('c.tesc', 'c.tesc.spv'), + ('d.tese', 'd.tese.spv'), + ('e.geom', 'e.geom.spv'), + ('f.comp', 'f.comp.spv'), + ('file', 'file.spv'), ('file.', 'file.spv'), + ('file.uk', + 'file.spv'), ('file.vert.', + 'file.vert.spv'), + ('file.vert.bla', + 'file.vert.spv')] + actual_object_names = [ + expect.get_object_filename(f[0]) for f in source_and_object_names + ] + expected_object_names = [f[1] for f in source_and_object_names] + + self.assertEqual(actual_object_names, expected_object_names) + + def test_stdout_match_regex_has_match(self): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout=b'0abc1', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + self.assertTrue(test.check_stdout_match(status)[0]) + + def test_stdout_match_regex_no_match(self): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout=b'ab', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + self.assertFalse(test.check_stdout_match(status)[0]) + + def test_stdout_match_regex_empty_stdout(self): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout=b'', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + self.assertFalse(test.check_stdout_match(status)[0]) diff --git a/3rdparty/spirv-tools/test/tools/spirv_test_framework.py b/3rdparty/spirv-tools/test/tools/spirv_test_framework.py index d8d64f3e4..42f83c64a 100755 --- a/3rdparty/spirv-tools/test/tools/spirv_test_framework.py +++ b/3rdparty/spirv-tools/test/tools/spirv_test_framework.py @@ -70,7 +70,7 @@ def get_all_methods(instance): def get_all_superclasses(cls): - """Returns all superclasses of a given class. + """Returns all superclasses of a given class. Omits root 'object' superclass. Returns: A list of superclasses of the given class. The order guarantees that @@ -83,11 +83,12 @@ def get_all_superclasses(cls): classes = [] for superclass in cls.__bases__: for c in get_all_superclasses(superclass): - if c not in classes: + if c is not object and c not in classes: classes.append(c) for superclass in cls.__bases__: - if superclass not in classes: + if superclass is not object and superclass not in classes: classes.append(superclass) + return classes @@ -142,8 +143,28 @@ class TestStatus: inputs, input_filenames): self.test_manager = test_manager self.returncode = returncode - self.stdout = stdout - self.stderr = stderr + # Some of our MacOS bots still run Python 2, so need to be backwards + # compatible here. + if type(stdout) is not str: + if sys.version_info[0] is 2: + self.stdout = stdout.decode('utf-8') + elif sys.version_info[0] is 3: + self.stdout = str(stdout, encoding='utf-8') if stdout is not None else stdout + else: + raise Exception('Unable to determine if running Python 2 or 3 from {}'.format(sys.version_info)) + else: + self.stdout = stdout + + if type(stderr) is not str: + if sys.version_info[0] is 2: + self.stderr = stderr.decode('utf-8') + elif sys.version_info[0] is 3: + self.stderr = str(stderr, encoding='utf-8') if stderr is not None else stderr + else: + raise Exception('Unable to determine if running Python 2 or 3 from {}'.format(sys.version_info)) + else: + self.stderr = stderr + # temporary directory where the test runs self.directory = directory # List of inputs, as PlaceHolder objects. diff --git a/3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py b/3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py deleted file mode 100755 index c0fbed581..000000000 --- a/3rdparty/spirv-tools/test/tools/spirv_test_framework_nosetest.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from spirv_test_framework import get_all_test_methods, get_all_superclasses -from nose.tools import assert_equal, with_setup - - -# Classes to be used in testing get_all_{superclasses|test_methods}() -class Root: - - def check_root(self): - pass - - -class A(Root): - - def check_a(self): - pass - - -class B(Root): - - def check_b(self): - pass - - -class C(Root): - - def check_c(self): - pass - - -class D(Root): - - def check_d(self): - pass - - -class E(Root): - - def check_e(self): - pass - - -class H(B, C, D): - - def check_h(self): - pass - - -class I(E): - - def check_i(self): - pass - - -class O(H, I): - - def check_o(self): - pass - - -class U(A, O): - - def check_u(self): - pass - - -class X(U, A): - - def check_x(self): - pass - - -class R1: - - def check_r1(self): - pass - - -class R2: - - def check_r2(self): - pass - - -class Multi(R1, R2): - - def check_multi(self): - pass - - -def nosetest_get_all_superclasses(): - """Tests get_all_superclasses().""" - - assert_equal(get_all_superclasses(A), [Root]) - assert_equal(get_all_superclasses(B), [Root]) - assert_equal(get_all_superclasses(C), [Root]) - assert_equal(get_all_superclasses(D), [Root]) - assert_equal(get_all_superclasses(E), [Root]) - - assert_equal(get_all_superclasses(H), [Root, B, C, D]) - assert_equal(get_all_superclasses(I), [Root, E]) - - assert_equal(get_all_superclasses(O), [Root, B, C, D, E, H, I]) - - assert_equal(get_all_superclasses(U), [Root, B, C, D, E, H, I, A, O]) - assert_equal(get_all_superclasses(X), [Root, B, C, D, E, H, I, A, O, U]) - - assert_equal(get_all_superclasses(Multi), [R1, R2]) - - -def nosetest_get_all_methods(): - """Tests get_all_test_methods().""" - assert_equal(get_all_test_methods(A), ['check_root', 'check_a']) - assert_equal(get_all_test_methods(B), ['check_root', 'check_b']) - assert_equal(get_all_test_methods(C), ['check_root', 'check_c']) - assert_equal(get_all_test_methods(D), ['check_root', 'check_d']) - assert_equal(get_all_test_methods(E), ['check_root', 'check_e']) - - assert_equal( - get_all_test_methods(H), - ['check_root', 'check_b', 'check_c', 'check_d', 'check_h']) - assert_equal(get_all_test_methods(I), ['check_root', 'check_e', 'check_i']) - - assert_equal( - get_all_test_methods(O), [ - 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', - 'check_i', 'check_o' - ]) - - assert_equal( - get_all_test_methods(U), [ - 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', - 'check_i', 'check_a', 'check_o', 'check_u' - ]) - assert_equal( - get_all_test_methods(X), [ - 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', - 'check_i', 'check_a', 'check_o', 'check_u', 'check_x' - ]) - - assert_equal( - get_all_test_methods(Multi), ['check_r1', 'check_r2', 'check_multi']) diff --git a/3rdparty/spirv-tools/test/tools/spirv_test_framework_unittest.py b/3rdparty/spirv-tools/test/tools/spirv_test_framework_unittest.py new file mode 100644 index 000000000..e64e86c01 --- /dev/null +++ b/3rdparty/spirv-tools/test/tools/spirv_test_framework_unittest.py @@ -0,0 +1,158 @@ +# Copyright (c) 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the spirv test framework module.""" + +from spirv_test_framework import get_all_test_methods, get_all_superclasses +import unittest + +# Classes to be used in testing get_all_{superclasses|test_methods}() + + +class Root: + + def check_root(self): + pass + + +class A(Root): + + def check_a(self): + pass + + +class B(Root): + + def check_b(self): + pass + + +class C(Root): + + def check_c(self): + pass + + +class D(Root): + + def check_d(self): + pass + + +class E(Root): + + def check_e(self): + pass + + +class H(B, C, D): + + def check_h(self): + pass + + +class I(E): + + def check_i(self): + pass + + +class O(H, I): + + def check_o(self): + pass + + +class U(A, O): + + def check_u(self): + pass + + +class X(U, A): + + def check_x(self): + pass + + +class R1: + + def check_r1(self): + pass + + +class R2: + + def check_r2(self): + pass + + +class Multi(R1, R2): + + def check_multi(self): + pass + + +class TestSpirvTestFramework(unittest.TestCase): + def test_get_all_superclasses(self): + self.assertEqual(get_all_superclasses(A), [Root]) + self.assertEqual(get_all_superclasses(B), [Root]) + self.assertEqual(get_all_superclasses(C), [Root]) + self.assertEqual(get_all_superclasses(D), [Root]) + self.assertEqual(get_all_superclasses(E), [Root]) + + self.assertEqual(get_all_superclasses(H), [Root, B, C, D]) + self.assertEqual(get_all_superclasses(I), [Root, E]) + + self.assertEqual(get_all_superclasses(O), [Root, B, C, D, E, H, I]) + + self.assertEqual(get_all_superclasses( + U), [Root, B, C, D, E, H, I, A, O]) + self.assertEqual(get_all_superclasses( + X), [Root, B, C, D, E, H, I, A, O, U]) + + self.assertEqual(get_all_superclasses(Multi), [R1, R2]) + + def test_get_all_methods(self): + self.assertEqual(get_all_test_methods(A), ['check_root', 'check_a']) + self.assertEqual(get_all_test_methods(B), ['check_root', 'check_b']) + self.assertEqual(get_all_test_methods(C), ['check_root', 'check_c']) + self.assertEqual(get_all_test_methods(D), ['check_root', 'check_d']) + self.assertEqual(get_all_test_methods(E), ['check_root', 'check_e']) + + self.assertEqual( + get_all_test_methods(H), + ['check_root', 'check_b', 'check_c', 'check_d', 'check_h']) + self.assertEqual(get_all_test_methods( + I), ['check_root', 'check_e', 'check_i']) + + self.assertEqual( + get_all_test_methods(O), [ + 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', + 'check_i', 'check_o' + ]) + + self.assertEqual( + get_all_test_methods(U), [ + 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', + 'check_i', 'check_a', 'check_o', 'check_u' + ]) + + self.assertEqual( + get_all_test_methods(X), [ + 'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h', + 'check_i', 'check_a', 'check_o', 'check_u', 'check_x' + ]) + + self.assertEqual( + get_all_test_methods(Multi), ['check_r1', 'check_r2', 'check_multi'])