ggerganov commited on
Commit
b42b45f
·
unverified ·
1 Parent(s): 66cb760

metal : fix asserts for setThreadgroupMemoryLength (close #1435)

Browse files
Files changed (1) hide show
  1. ggml-metal.m +3 -3
ggml-metal.m CHANGED
@@ -1030,7 +1030,7 @@ void ggml_metal_graph_compute(
1030
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1031
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1032
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1033
- [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1034
 
1035
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1036
  } break;
@@ -1342,7 +1342,7 @@ void ggml_metal_graph_compute(
1342
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1343
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1344
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1345
- [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1346
 
1347
  const int64_t nrows = ggml_nrows(src0);
1348
 
@@ -1361,7 +1361,7 @@ void ggml_metal_graph_compute(
1361
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1362
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1363
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1364
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
1365
 
1366
  const int64_t nrows = ggml_nrows(src0);
1367
 
 
1030
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1031
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1032
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1033
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1034
 
1035
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1036
  } break;
 
1342
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1343
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1344
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1345
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1346
 
1347
  const int64_t nrows = ggml_nrows(src0);
1348
 
 
1361
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1362
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1363
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1364
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1365
 
1366
  const int64_t nrows = ggml_nrows(src0);
1367