summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkennyballou <kballou@onyx.boisestate.edu>2013-06-21 17:05:57 -0600
committerkennyballou <kballou@onyx.boisestate.edu>2013-06-21 17:07:40 -0600
commitf10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c (patch)
treebf94afbb930997b049b8c54b8148951141a9da35
parentfbe99bf593c0b0fb12c43bcf2c7b9b2c50cb850b (diff)
downloadxnt-f10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c.tar.gz
xnt-f10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c.tar.xz
Add NVCC compiler to xnt.build.cc
-rw-r--r--xnt/build/cc.py21
-rw-r--r--xnt/tests/compilercollectiontests.py39
2 files changed, 59 insertions, 1 deletions
diff --git a/xnt/build/cc.py b/xnt/build/cc.py
index 5cb455f..967608f 100644
--- a/xnt/build/cc.py
+++ b/xnt/build/cc.py
@@ -22,7 +22,7 @@ Definition of commonly used compilers
import os
import logging
-from xnt.tasks import call
+from xnt.tasks import call, which
LOGGER = logging.getLogger(__name__)
@@ -70,6 +70,25 @@ def javac(src, flags=None):
cmd = __generate_command(src, flags, "javac")
return __compile(cmd)
+def nvcc(src, flags=None):
+ """NVCC: compile CUDA C/C++ programs
+
+ :param src: CUDA source file to compile with default `nvcc`
+ :param flags: List of flags to pass onto the compiler
+ """
+ assert which('nvcc')
+ return _gcc(src, flags, compiler='nvcc')
+
+def nvcc_o(src, output, flags=None):
+ """NVCC: compile with named output
+
+ :param src: CUDA source file to compile with default `nvcc`
+ :param output: Name of resulting object or executable
+ :param flags: List of flags to pass onto the compiler
+ """
+ assert which('nvcc')
+ return _gcc_o(src, output, flags, compiler='nvcc')
+
def _gcc(src, flags=None, compiler="gcc"):
"""Compile using gcc"""
LOGGER.info("Compiling %s", src)
diff --git a/xnt/tests/compilercollectiontests.py b/xnt/tests/compilercollectiontests.py
index be72c81..cb5b3fc 100644
--- a/xnt/tests/compilercollectiontests.py
+++ b/xnt/tests/compilercollectiontests.py
@@ -88,6 +88,45 @@ class GppTests(unittest.TestCase):
cc.gpp_o("temp/hello.cpp", "temp/hello")
self.assertTrue(os.path.isfile("temp/hello"))
+@unittest.skipUnless(xnt.in_path('nvcc'), 'nvcc is not in your path')
+class NvccTests(unittest.TestCase):
+ """Test NVCC"""
+ def setUp(self):
+ """Test Case Setup"""
+ os.mkdir("temp")
+ with open("temp/hello.cu", "w") as test_code:
+ test_code.write("""
+ __global__ void kernel(float *x) {
+ int idx = threadIdx.x;
+ x[idx] = 42;
+ }
+ int main() {
+ int size = sizeof(float) * 128;
+ float *x = (float*)malloc(size);
+ float *dev_x;
+ cudaMalloc((void**)&dev_x, size);
+ cudaMemcpy(dev_x, x, size, cudaMemcpyHostToDevice);
+ kernel<<<128, 1>>>(dev_x);
+ cudaMemcpy(x, dev_x, size, cudaMemcpyDeviceToHost);
+ cudaFree(dev_x);
+ }""")
+ def tearDown(self):
+ """Test Case Teardown"""
+ shutil.rmtree("temp")
+
+ def test_nvcc(self):
+ """Test Default NVCC"""
+ cwd = os.getcwd()
+ os.chdir("temp")
+ cc.nvcc("hello.cu")
+ self.assertTrue(os.path.isfile("a.out"))
+ os.chdir(cwd)
+
+ def test_nvcc_o(self):
+ """Test Named Output NVCC"""
+ cc.nvcc_o("temp/hello.cu", "temp/hello")
+ self.assertTrue(os.path.isfile("temp/hello"))
+
@unittest.skipUnless(xnt.in_path("javac"), "javac is not in your path")
class JavacTests(unittest.TestCase):
"""Test Javac"""