summaryrefslogtreecommitdiff
blob: 8e6753781a2a1921893a8161919eb929e073201d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
https://github.com/ROCmSoftwarePlatform/Tensile/issues/1395
https://github.com/ROCmSoftwarePlatform/Tensile/pull/1398

--- a/Tensile/TensileCreateLibrary.py
+++ b/Tensile/TensileCreateLibrary.py
@@ -136,6 +136,35 @@ def which(p):
                 return candidate
     return None
 
+def splitArchs():
+  # Helper for architecture
+  def isSupported(arch):
+    return globalParameters["AsmCaps"][arch]["SupportedISA"] and \
+           globalParameters["AsmCaps"][arch]["SupportedSource"]
+
+  if ";" in globalParameters["Architecture"]:
+    wantedArchs = globalParameters["Architecture"].split(";")
+  else:
+    wantedArchs = globalParameters["Architecture"].split("_")
+  archs = []
+  cmdlineArchs = []
+  if "all" in wantedArchs:
+    for arch in globalParameters['SupportedISA']:
+      if isSupported(arch):
+        if (arch == (9,0,6) or arch == (9,0,8) or arch == (9,0,10)):
+          if (arch == (9,0,10)):
+            archs += [gfxName(arch) + '-xnack+']
+            cmdlineArchs += [gfxName(arch) + ':xnack+']
+          archs += [gfxName(arch) + '-xnack-']
+          cmdlineArchs += [gfxName(arch) + ':xnack-']
+        else:
+          archs += [gfxName(arch)]
+          cmdlineArchs += [gfxName(arch)]
+  else:
+    for arch in wantedArchs:
+      archs += [re.sub(":", "-", arch)]
+      cmdlineArchs += [arch]
+  return archs, cmdlineArchs
 
 def buildSourceCodeObjectFile(CxxCompiler, outputPath, kernelFile):
     buildPath = ensurePath(os.path.join(globalParameters['WorkingPath'], 'code_object_tmp'))
@@ -149,24 +178,8 @@ def buildSourceCodeObjectFile(CxxCompiler, outputPath, kernelFile):
     objectFilename = base + '.o'
     soFilename = base + '.so'
 
-    def isSupported(arch):
-        return globalParameters["AsmCaps"][arch]["SupportedISA"] and \
-               globalParameters["AsmCaps"][arch]["SupportedSource"]
-
     if (CxxCompiler == "hipcc"):
-      archs = []
-      cmdlineArchs = []
-      for arch in globalParameters['SupportedISA']:
-        if isSupported(arch):
-          if (arch == (9,0,6) or arch == (9,0,8) or arch == (9,0,10)):
-            if (arch == (9,0,10)):
-              archs += [gfxName(arch) + '-xnack+']
-              cmdlineArchs += [gfxName(arch) + ':xnack+']
-            archs += [gfxName(arch) + '-xnack-']
-            cmdlineArchs += [gfxName(arch) + ':xnack-']
-          else:
-            archs += [gfxName(arch)]
-            cmdlineArchs += [gfxName(arch)]
+      archs, cmdlineArchs = splitArchs()
 
       archFlags = ['--offload-arch=' + arch for arch in cmdlineArchs]
 
@@ -1063,11 +1076,6 @@ def buildObjectFileNames(solutionWriter, kernelWriterSource, kernelWriterAssembl
   sourceKernels = list([k for k in kernels if k['KernelLanguage'] == 'Source'])
   asmKernels = list([k for k in kernels if k['KernelLanguage'] == 'Assembly'])
 
-  # Helper for architecture
-  def isSupported(arch):
-        return globalParameters["AsmCaps"][arch]["SupportedISA"] and \
-               globalParameters["AsmCaps"][arch]["SupportedSource"]
-
   # Build a list of kernel object names.
   for kernel in sourceKernels:
     sourceKernelNames += [kernelWriterSource.getKernelFileBase(kernel)]
@@ -1081,15 +1089,7 @@ def buildObjectFileNames(solutionWriter, kernelWriterSource, kernelWriterAssembl
 
   # Source based kernels are built for all supported architectures
   if (cxxCompiler == 'hipcc'):
-    sourceArchs = []
-    for arch in globalParameters['SupportedISA']:
-      if isSupported(arch):
-        if (arch == (9,0,6) or arch == (9,0,8) or arch == (9,0,10)):
-          if (arch == (9,0,10)):
-            sourceArchs += [gfxName(arch) + '-xnack+']
-          sourceArchs += [gfxName(arch) + '-xnack-']
-        else:
-          sourceArchs += [gfxName(arch)]
+    sourceArchs, _ = splitArchs()
   else:
     raise RuntimeError("Unknown compiler %s" % cxxCompiler)