diff options
author | kennyballou <kballou@onyx.boisestate.edu> | 2013-06-21 17:05:57 -0600 |
---|---|---|
committer | kennyballou <kballou@onyx.boisestate.edu> | 2013-06-21 17:07:40 -0600 |
commit | f10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c (patch) | |
tree | bf94afbb930997b049b8c54b8148951141a9da35 | |
parent | fbe99bf593c0b0fb12c43bcf2c7b9b2c50cb850b (diff) | |
download | xnt-f10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c.tar.gz xnt-f10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c.tar.xz |
Add NVCC compiler to xnt.build.cc
-rw-r--r-- | xnt/build/cc.py | 21 | ||||
-rw-r--r-- | xnt/tests/compilercollectiontests.py | 39 |
2 files changed, 59 insertions, 1 deletions
diff --git a/xnt/build/cc.py b/xnt/build/cc.py index 5cb455f..967608f 100644 --- a/xnt/build/cc.py +++ b/xnt/build/cc.py @@ -22,7 +22,7 @@ Definition of commonly used compilers import os import logging -from xnt.tasks import call +from xnt.tasks import call, which LOGGER = logging.getLogger(__name__) @@ -70,6 +70,25 @@ def javac(src, flags=None): cmd = __generate_command(src, flags, "javac") return __compile(cmd) +def nvcc(src, flags=None): + """NVCC: compile CUDA C/C++ programs + + :param src: CUDA source file to compile with default `nvcc` + :param flags: List of flags to pass onto the compiler + """ + assert which('nvcc') + return _gcc(src, flags, compiler='nvcc') + +def nvcc_o(src, output, flags=None): + """NVCC: compile with named output + + :param src: CUDA source file to compile with default `nvcc` + :param output: Name of resulting object or executable + :param flags: List of flags to pass onto the compiler + """ + assert which('nvcc') + return _gcc_o(src, output, flags, compiler='nvcc') + def _gcc(src, flags=None, compiler="gcc"): """Compile using gcc""" LOGGER.info("Compiling %s", src) diff --git a/xnt/tests/compilercollectiontests.py b/xnt/tests/compilercollectiontests.py index be72c81..cb5b3fc 100644 --- a/xnt/tests/compilercollectiontests.py +++ b/xnt/tests/compilercollectiontests.py @@ -88,6 +88,45 @@ class GppTests(unittest.TestCase): cc.gpp_o("temp/hello.cpp", "temp/hello") self.assertTrue(os.path.isfile("temp/hello")) +@unittest.skipUnless(xnt.in_path('nvcc'), 'nvcc is not in your path') +class NvccTests(unittest.TestCase): + """Test NVCC""" + def setUp(self): + """Test Case Setup""" + os.mkdir("temp") + with open("temp/hello.cu", "w") as test_code: + test_code.write(""" + __global__ void kernel(float *x) { + int idx = threadIdx.x; + x[idx] = 42; + } + int main() { + int size = sizeof(float) * 128; + float *x = (float*)malloc(size); + float *dev_x; + cudaMalloc((void**)&dev_x, size); + cudaMemcpy(dev_x, x, size, cudaMemcpyHostToDevice); + kernel<<<128, 1>>>(dev_x); + cudaMemcpy(x, dev_x, size, cudaMemcpyDeviceToHost); + cudaFree(dev_x); + }""") + def tearDown(self): + """Test Case Teardown""" + shutil.rmtree("temp") + + def test_nvcc(self): + """Test Default NVCC""" + cwd = os.getcwd() + os.chdir("temp") + cc.nvcc("hello.cu") + self.assertTrue(os.path.isfile("a.out")) + os.chdir(cwd) + + def test_nvcc_o(self): + """Test Named Output NVCC""" + cc.nvcc_o("temp/hello.cu", "temp/hello") + self.assertTrue(os.path.isfile("temp/hello")) + @unittest.skipUnless(xnt.in_path("javac"), "javac is not in your path") class JavacTests(unittest.TestCase): """Test Javac""" |