[Opt] Manipulation of Q8_0 tensors with Tornado ByteArrays #79
Changes from all commits (30 commits, e6ce9e0 through 6b66a59).
First changed file (abstract class State):

@@ -2,8 +2,8 @@
    import org.beehive.gpullama3.tensor.standard.FloatTensor;
    import org.beehive.gpullama3.model.Configuration;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
    import uk.ac.manchester.tornado.api.types.arrays.IntArray;
    import uk.ac.manchester.tornado.api.types.HalfFloat;
    import uk.ac.manchester.tornado.api.types.arrays.*;

    /**
     * Represents the base state structure used during LLM inference.

@@ -57,6 +57,7 @@ public abstract class State {
    public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
    public final IntArray positionHolder;

    public TornadoNativeArray embeddingX;

    // store inter
    public int localSize;
    public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.

@@ -88,6 +89,7 @@ protected State(Configuration config, int batchsize) {
    this.keyCache = fields.keyCache;
    this.valueCache = fields.valueCache;

    this.embeddingX = fields.embeddingX;
    this.wrapX = fields.wrapX;
    this.wrapXb = fields.wrapXb;
    this.wrapXb2 = fields.wrapXb2;

@@ -121,6 +123,19 @@ protected static class StateFields {
    public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
    public IntArray positionHolder;
    public FloatArray temp, tempFFN, tempLogits;
    public TornadoNativeArray embeddingX;

    public void createActivationFP16(int size) {
        this.embeddingX = new HalfFloatArray(size);
    }

    public void createActivationQ8_0(int size) {
        int blockSize = 32;
        int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
        int blocksNeeded = (size + blockSize - 1) / blockSize;
        int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
        this.embeddingX = new ByteArray(q8BytesNeeded);
    }
    }

    @Override
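For reference, the 34-byte arithmetic in createActivationQ8_0 matches the standard GGUF block_q8_0 layout: each block packs a 2-byte FP16 scale followed by 32 signed int8 quants. The sketch below is not part of this diff; it shows how one such block could be dequantized from that packed byte layout. It deliberately uses a plain byte[] and Float.float16ToFloat (Java 20+) rather than Tornado's ByteArray/HalfFloat types so it stays self-contained, and the little-endian scale encoding is an assumption carried over from GGUF.

```java
// Sketch only: dequantize one Q8_0 block from the packed layout assumed above
// (2-byte little-endian FP16 scale, then 32 signed int8 quants = 34 bytes).
public final class Q8_0BlockSketch {
    static final int BLOCK_SIZE = 32;   // values per Q8_0 block
    static final int BLOCK_BYTES = 34;  // 2-byte scale + 32 quants

    // Dequantizes block `blockIndex` of `packed` into `out` (length >= 32).
    static void dequantizeBlock(byte[] packed, int blockIndex, float[] out) {
        int base = blockIndex * BLOCK_BYTES;
        // Reassemble the FP16 scale from its two little-endian bytes.
        short scaleBits = (short) ((packed[base] & 0xFF) | ((packed[base + 1] & 0xFF) << 8));
        float scale = Float.float16ToFloat(scaleBits); // Java 20+
        for (int i = 0; i < BLOCK_SIZE; i++) {
            out[i] = packed[base + 2 + i] * scale;     // signed int8 quant * scale
        }
    }

    public static void main(String[] args) {
        byte[] packed = new byte[BLOCK_BYTES];
        packed[0] = 0x00;
        packed[1] = 0x3C;  // 0x3C00 is FP16 for 1.0 (little-endian: low byte first)
        packed[2] = 5;     // first quant
        float[] out = new float[BLOCK_SIZE];
        dequantizeBlock(packed, 0, out);
        System.out.println(out[0]); // prints 5.0
    }
}
```

The same index math, (size + 31) / 32 blocks at 34 bytes each, is what the new createActivationQ8_0 uses to size the ByteArray.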
Second changed file (AbstractModelLoader):

@@ -35,6 +35,15 @@ protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLen
    this.useTornadovm = useTornadovm;
    }

    protected String getModelQuantization(Map<String, Object> metadata) {
        int modelQuantizationAsInt = (int) metadata.get("general.file_type");
        return switch (modelQuantizationAsInt) {
            case 1 -> "FP16";
            case 7 -> "Q8_0";
            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + modelQuantizationAsInt + " (as int).");
        };
    }

    /**
     * Template method that defines the model loading workflow. Subclasses should not override this method.
     *

Review comment on lines +40 to +42 (Member): what are these magic numbers 1 & 7?
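On the magic-numbers question above: the values appear to come from the GGUF general.file_type metadata key, whose enumeration (shared with llama.cpp's file types) uses 1 for mostly-F16 and 7 for mostly-Q8_0 models. Below is a minimal sketch of how named constants could make that explicit; the constant names are hypothetical and this is an illustration, not a change made in this PR.

```java
// Sketch only: same switch as getModelQuantization, but with named constants
// for the general.file_type values (names are hypothetical; the 1/7 mapping
// is assumed to follow the GGUF / llama.cpp file-type enumeration).
import java.util.Map;

final class FileTypeSketch {
    static final int FILE_TYPE_MOSTLY_F16 = 1;
    static final int FILE_TYPE_MOSTLY_Q8_0 = 7;

    static String quantizationName(Map<String, Object> metadata) {
        int fileType = (int) metadata.get("general.file_type");
        return switch (fileType) {
            case FILE_TYPE_MOSTLY_F16 -> "FP16";
            case FILE_TYPE_MOSTLY_Q8_0 -> "Q8_0";
            default -> throw new UnsupportedOperationException(
                    "Unsupported quantization format: " + fileType);
        };
    }

    public static void main(String[] args) {
        Map<String, Object> metadata = Map.of("general.file_type", 7);
        System.out.println(quantizationName(metadata)); // prints Q8_0
    }
}
```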