2
0
Эх сурвалжийг харах

разобрался с группами и осями

p.kushnir 6 жил өмнө
parent
commit
31eed913ac

+ 186 - 12
src/main/java/in/ocsf/bee/freigeld/core/cl/Sample1.java

@@ -3,24 +3,115 @@ package in.ocsf.bee.freigeld.core.cl;
 import com.aparapi.Kernel;
 import com.aparapi.Range;
 import com.aparapi.device.Device;
+import com.aparapi.device.OpenCLDevice;
+import com.fasterxml.jackson.databind.ObjectMapper;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.logging.Logger;
 
 public class Sample1 implements Runnable {
 
+    private final ObjectMapper objectMapper = new ObjectMapper();
+
     private final Logger log = Logger.getLogger(getClass().getName());
+    private final Integer TOTAL_MEM;
+    private OpenCLDevice device;
+
+    public Sample1() {
+        Device _device = Device.best();
+        if (_device instanceof OpenCLDevice)
+            this.device = (OpenCLDevice) _device;
+        else
+            throw new RuntimeException("no opencl device");
+
+        log.info(device.getName());
+        TOTAL_MEM = Long.valueOf(Math.min(128, device.getMaxMemAllocSize() / 1024 / 1024 / 2)).intValue() * 1024 * 1024;
+    }
 
     @Override
     public void run() {
         final int N = 4;
-        final int R = 2;
+        final int G = 32;
 
-        final long[][][] mem = new long[R][R][R];
+        final int pB = 4 * 1024;
+        final int pBtotal = 0;
+        final int pBfrom = 1;
+        final int pBto = 2;
+        final int pBpos = 3;
+        final int pBlen = 4;
 
-        Kernel kernel = new Kernel() {
+        final byte[] mm = new byte[TOTAL_MEM];
+        final long[] arg = new long[3];
+        final long[][] mem = new long[N][G];
+        final long[][] res = new long[2][16];
+        final AtomicInteger[] at = new AtomicInteger[]{new AtomicInteger(0), new AtomicInteger(0), new AtomicInteger(0)};
 
-            @Override
-            public void run() {
+        Kernel kernel = new DryRunnable() {
+
+            protected @PrivateMemorySpace(1)
+            boolean[] dryRun = new boolean[]{false};
+
+            protected @PrivateMemorySpace(pB)
+            byte[] membuf = new byte[pB];
+            protected @PrivateMemorySpace(5)
+            int[] bufposlimits = new int[5];
+
+            private void mem2buf() {
+                int p0 = bufposlimits[pBpos];
+                int len = bufposlimits[pBlen];
+                int p = p0;
+                while (p < p0 + len - 1) {
+                    membuf[p - p0] = mm[p];
+                    p++;
+                }
+            }
+
+            private void buf2mem() {
+                int p0 = bufposlimits[pBpos];
+                int len = bufposlimits[pBlen];
+                int p = p0;
+                while (p < p0 + len - 1) {
+                    mm[p] = membuf[p - p0];
+                    p++;
+                }
+            }
+
+            private void initbuf() {
+                int x = dryRun[0] ? 3 : getGlobalId(0);
+                int xN = dryRun[0] ? G : getGlobalSize(0);
+                int block = mm.length / xN;
+                bufposlimits[pBtotal] = block;
+                int from = x * block;
+                int to = (x + 1) * block - 1;
+                bufposlimits[pBfrom] = from;
+                bufposlimits[pBto] = to;
+                bufpos(0);
+            }
+
+            private void bufpos(int pos) {
+                if (pos >= bufposlimits[pBpos] && pos < (bufposlimits[pBpos] + bufposlimits[pBlen]))
+                    return;
+
+                bufposlimits[pBpos] = max(0, min(bufposlimits[pBtotal] - pB, pos));
+                bufposlimits[pBlen] = min(bufposlimits[pBtotal] - bufposlimits[pBpos], pB);
+                mem2buf();
+            }
+
+            private void memset(int pos, byte val) {
+                bufpos(pos);
+                int pos0 = bufposlimits[pBpos];
+                membuf[pos - pos0] = val;
+            }
+
+            private byte memget(int pos) {
+                bufpos(pos);
+                int pos0 = bufposlimits[pBpos];
+                return membuf[pos - pos0];
+            }
+
+            private void test() {
                 int x = getGlobalId(0);
                 int y = getGlobalId(1);
                 int z = getGlobalId(2);
@@ -33,16 +124,99 @@ public class Sample1 implements Runnable {
                 int gxN = getNumGroups(0);
                 int gyN = getNumGroups(1);
                 int gzN = getNumGroups(2);
-                int l = getLocalId();
-                int lN = getLocalSize();
+                int lx = getLocalId(0);
+                int ly = getLocalId(1);
+                int lz = getLocalId(2);
+                int lxN = getLocalSize(0);
+                int lyN = getLocalSize(1);
+                int lzN = getLocalSize(2);
                 int p = getPassId();
-                mem[x][y][z] = gxN * 100 + gyN * 10 + gzN;
+
+                mem[0][x] = (x + 1) * 1000000 + (gx + 1) * 1000 + (lx + 1);
+                mem[1][y] = (y + 1) * 1000000 + (gy + 1) * 1000 + (ly + 1);
+                mem[2][z] = (z + 1) * 1000000 + (gz + 1) * 1000 + (lz + 1);
+
+                res[0][0] = p;
+                res[0][1] = xN;
+                res[0][2] = yN;
+                res[0][3] = zN;
+                res[0][4] = gxN;
+                res[0][5] = gyN;
+                res[0][6] = gzN;
+                res[0][7] = lxN;
+                res[0][8] = lyN;
+                res[0][9] = lzN;
+
+                atomicInc(at[0]);
+                atomicInc(at[1]);
+                atomicInc(at[2]);
+                int block = mm.length / xN;
+                int from = x * block;
+                int to = (x + 1) * block - 1;
+                mem[3][x] = to;
+                int i = from;
+                while (i <= to) {
+                    mm[i] = (byte) (x + 1);
+                    i++;
+                }
+            }
+
+
+            @Override
+            public void run() {
+                initbuf();
+                memset(0, (byte) 1);
+                byte val = memget(0);
+                if (val == 1) memset(1, (byte) 1);
+                buf2mem();
+            }
+
+            public void dryRun() {
+                dryRun[0] = true;
+                run();
+                dryRun[0] = false;
             }
         };
 
-        Device device = Device.best();
-        Range range = Range.create3D(device, R, R, R);
-        kernel.execute(range, 1);
-        log.info("done");
+        Range range = Range.create(device, G);
+        /*switch (device.getMaxWorkItemDimensions()){
+            case 1:
+                range = Range.create(device, G);
+                break;
+            case 2:
+                range = Range.create2D(device, G, G);
+                break;
+            case 3:
+                range = Range.create3D(device, G, G, G);
+                break;
+            default:
+                throw new RuntimeException("not supported");
+        }*/
+        try {
+            ((DryRunnable) kernel).dryRun();
+            kernel.execute(range, 1);
+        } catch (Exception e) {
+            throw new RuntimeException("dry run failed", e);
+        }
+        Map<String, Object> info = new HashMap<>();
+        info.put("time", kernel.getAccumulatedExecutionTime() / 1000.0);
+        info.put("memg", TOTAL_MEM / 1024.0 / 1024.0);
+        info.put("meml", device.getLocalMemSize() / 1024.0);
+        try {
+            log.info(objectMapper.writer().withDefaultPrettyPrinter().writeValueAsString(info));
+            log.info(objectMapper.writeValueAsString(mem[0]));
+            log.info(objectMapper.writeValueAsString(mem[1]));
+            log.info(objectMapper.writeValueAsString(mem[2]));
+            log.info(objectMapper.writeValueAsString(res[0]));
+            log.info(objectMapper.writeValueAsString(res[1]));
+            log.info(objectMapper.writeValueAsString(at));
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+        kernel.dispose();
+    }
+
+    abstract class DryRunnable extends Kernel {
+        abstract void dryRun();
     }
 }

+ 16 - 0
src/main/kotlin/in/ocsf/bee/freigeld/core/FreiApp.kt

@@ -4,13 +4,19 @@ import `in`.ocsf.bee.freigeld.core.cl.Sample0
 import `in`.ocsf.bee.freigeld.core.cl.Sample1
 import org.springframework.boot.SpringApplication
 import org.springframework.boot.autoconfigure.SpringBootApplication
+import org.springframework.context.annotation.Profile
 import org.springframework.context.annotation.PropertySource
+import org.springframework.scheduling.annotation.EnableScheduling
+import org.springframework.scheduling.annotation.Scheduled
+import org.springframework.stereotype.Service
 import org.springframework.web.bind.annotation.RequestMapping
 import org.springframework.web.bind.annotation.RestController
+import org.springframework.web.client.RestTemplate
 import sun.misc.Unsafe
 import javax.annotation.PostConstruct
 
 @SpringBootApplication
+@EnableScheduling
 @PropertySource("classpath:application.yaml")
 class FreiApp{
 
@@ -34,6 +40,16 @@ class DevController{
     }
 }
 
+@Service
+@Profile("dev")
+class DevService {
+
+    @Scheduled(fixedRate = 10000L, initialDelay = 5000L)
+    fun init() {
+        RestTemplate().getForObject("http://127.0.0.1:4200/frei/dev/test1", String::class.java)
+    }
+}
+
 fun main(args: Array<String>) {
     disableWarning()
     SpringApplication.run(FreiApp::class.java, *args)