From f10e7e3b1588fc3b2d29cfb3a6e9ff0afb0b685c Mon Sep 17 00:00:00 2001 From: kennyballou Date: Fri, 21 Jun 2013 17:05:57 -0600 Subject: Add NVCC compiler to xnt.build.cc --- xnt/build/cc.py | 21 ++++++++++++++++++- xnt/tests/compilercollectiontests.py | 39 ++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) (limited to 'xnt') diff --git a/xnt/build/cc.py b/xnt/build/cc.py index 5cb455f..967608f 100644 --- a/xnt/build/cc.py +++ b/xnt/build/cc.py @@ -22,7 +22,7 @@ Definition of commonly used compilers import os import logging -from xnt.tasks import call +from xnt.tasks import call, which LOGGER = logging.getLogger(__name__) @@ -70,6 +70,25 @@ def javac(src, flags=None): cmd = __generate_command(src, flags, "javac") return __compile(cmd) +def nvcc(src, flags=None): + """NVCC: compile CUDA C/C++ programs + + :param src: CUDA source file to compile with default `nvcc` + :param flags: List of flags to pass onto the compiler + """ + assert which('nvcc') + return _gcc(src, flags, compiler='nvcc') + +def nvcc_o(src, output, flags=None): + """NVCC: compile with named output + + :param src: CUDA source file to compile with default `nvcc` + :param output: Name of resulting object or executable + :param flags: List of flags to pass onto the compiler + """ + assert which('nvcc') + return _gcc_o(src, output, flags, compiler='nvcc') + def _gcc(src, flags=None, compiler="gcc"): """Compile using gcc""" LOGGER.info("Compiling %s", src) diff --git a/xnt/tests/compilercollectiontests.py b/xnt/tests/compilercollectiontests.py index be72c81..cb5b3fc 100644 --- a/xnt/tests/compilercollectiontests.py +++ b/xnt/tests/compilercollectiontests.py @@ -88,6 +88,45 @@ class GppTests(unittest.TestCase): cc.gpp_o("temp/hello.cpp", "temp/hello") self.assertTrue(os.path.isfile("temp/hello")) +@unittest.skipUnless(xnt.in_path('nvcc'), 'nvcc is not in your path') +class NvccTests(unittest.TestCase): + """Test NVCC""" + def setUp(self): + """Test Case Setup""" + os.mkdir("temp") + with open("temp/hello.cu", "w") as test_code: + test_code.write(""" + __global__ void kernel(float *x) { + int idx = threadIdx.x; + x[idx] = 42; + } + int main() { + int size = sizeof(float) * 128; + float *x = (float*)malloc(size); + float *dev_x; + cudaMalloc((void**)&dev_x, size); + cudaMemcpy(dev_x, x, size, cudaMemcpyHostToDevice); + kernel<<<128, 1>>>(dev_x); + cudaMemcpy(x, dev_x, size, cudaMemcpyDeviceToHost); + cudaFree(dev_x); + }""") + def tearDown(self): + """Test Case Teardown""" + shutil.rmtree("temp") + + def test_nvcc(self): + """Test Default NVCC""" + cwd = os.getcwd() + os.chdir("temp") + cc.nvcc("hello.cu") + self.assertTrue(os.path.isfile("a.out")) + os.chdir(cwd) + + def test_nvcc_o(self): + """Test Named Output NVCC""" + cc.nvcc_o("temp/hello.cu", "temp/hello") + self.assertTrue(os.path.isfile("temp/hello")) + @unittest.skipUnless(xnt.in_path("javac"), "javac is not in your path") class JavacTests(unittest.TestCase): """Test Javac""" -- cgit v1.2.1