Skip to content

Commit 11562df

Browse files
Formatting
1 parent 8fb6cd1 commit 11562df

File tree

10 files changed

+41
-55
lines changed

10 files changed

+41
-55
lines changed

src/main/java/org/beehive/gpullama3/model/ModelType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* <p><b>Usage:</b> Use {@code ModelType} to specify or retrieve the type of
1717
* large language model (LLM), such as Llama or Qwen3. This ensures clean and structured handling of model behaviors and configurations by
1818
* dispatching calls to the appropriate model loader for each
19-
* model type.</p>
19+
* model type.</p>
2020
*
2121
* <p>Each enum value represents a distinct model type, which might be used for
2222
* conditional logic, initialization, or resource allocation within GPULlama3.java.</p>

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616
/**
1717
* Abstract base class for model loaders using Template Method pattern. Provides common loading flow with extension points for model-specific logic.
1818
*
19-
* @param <M>
20-
* The specific Model type to load
21-
* @param <C>
22-
* The specific Configuration type for the model
19+
* @param <M> The specific Model type to load
20+
* @param <C> The specific Configuration type for the model
2321
*/
2422
public abstract class AbstractModelLoader<M extends Model, C extends Configuration> {
2523

@@ -77,39 +75,33 @@ public final M loadModel() {
7775
/**
7876
* Load the vocabulary from GGUF metadata. Model-specific implementations should override this method.
7977
*
80-
* @param metadata
81-
* The GGUF metadata map
78+
* @param metadata The GGUF metadata map
8279
* @return The loaded Vocabulary
8380
*/
8481
protected abstract Vocabulary loadVocabulary(Map<String, Object> metadata);
8582

8683
/**
8784
* Create a tokenizer instance for this model.
8885
*
89-
* @param metadata
90-
* The GGUF metadata map
91-
* @param vocabulary
92-
* The loaded vocabulary
86+
* @param metadata The GGUF metadata map
87+
* @param vocabulary The loaded vocabulary
9388
* @return The tokenizer instance
9489
*/
9590
protected abstract Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary);
9691

9792
/**
9893
* Create a configuration instance from GGUF metadata.
9994
*
100-
* @param metadata
101-
* The GGUF metadata map
95+
* @param metadata The GGUF metadata map
10296
* @return The configuration instance
10397
*/
10498
protected abstract C createConfiguration(Map<String, Object> metadata);
10599

106100
/**
107101
* Load model weights from tensor entries. Default implementation handles common weight loading logic.
108102
*
109-
* @param tensorEntries
110-
* Map of tensor names to tensor entries
111-
* @param config
112-
* The model configuration
103+
* @param tensorEntries Map of tensor names to tensor entries
104+
* @param config The model configuration
113105
* @return The loaded weights
114106
*/
115107
public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config) {
@@ -131,12 +123,9 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config)
131123
/**
132124
* Create the final model instance.
133125
*
134-
* @param config
135-
* The model configuration
136-
* @param tokenizer
137-
* The tokenizer
138-
* @param weights
139-
* The loaded weights
126+
* @param config The model configuration
127+
* @param tokenizer The tokenizer
128+
* @param weights The loaded weights
140129
* @return The model instance
141130
*/
142131
protected abstract M createModel(C config, Tokenizer tokenizer, Weights weights);
@@ -164,11 +153,11 @@ protected GGMLTensorEntry getOutputWeight(Map<String, GGMLTensorEntry> tensorEnt
164153
* Create standard (CPU) weights.
165154
*/
166155
protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
167-
GGMLTensorEntry outputWeight);
156+
GGMLTensorEntry outputWeight);
168157

169158
/**
170159
* Create TornadoVM (GPU) weights.
171160
*/
172161
protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
173-
GGMLTensorEntry outputWeight);
162+
GGMLTensorEntry outputWeight);
174163
}

src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weig
7373

7474
@Override
7575
protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
76-
GGMLTensorEntry outputWeight) {
76+
GGMLTensorEntry outputWeight) {
7777

7878
final int nl = config.numberOfLayers();
7979

src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ protected Mistral createModel(MistralConfiguration config, Tokenizer tokenizer,
7070

7171
@Override
7272
protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
73-
GGMLTensorEntry outputWeight) {
73+
GGMLTensorEntry outputWeight) {
7474

7575
final int nl = config.numberOfLayers();
7676

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.nio.FloatBuffer;
2222
import java.nio.channels.FileChannel;
2323
import java.nio.file.Path;
24-
import java.nio.file.StandardOpenOption;
2524
import java.util.Map;
2625
import java.util.Set;
2726
import java.util.function.IntFunction;
@@ -74,13 +73,10 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
7473
* If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
7574
* </p>
7675
*
77-
* @param options
78-
* the parsed CLI options containing model path and max token limit
76+
* @param options the parsed CLI options containing model path and max token limit
7977
* @return the loaded {@link Model} instance
80-
* @throws IOException
81-
* if the model fails to load
82-
* @throws IllegalStateException
83-
* if AOT loading is enabled but the preloaded model is unavailable
78+
* @throws IOException if the model fails to load
79+
* @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
8480
*/
8581
public static Model loadModel(Options options) throws IOException {
8682
Path ggufPath = options.modelPath();
@@ -163,7 +159,7 @@ public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
163159
HalfFloatArray tensorHFA = tensor.asHalfFloatArray();
164160
int numOfElements = tensorHFA.getSize();
165161
FloatArray tensorFA = new FloatArray(numOfElements);
166-
for(int i = 0; i < numOfElements; i++) {
162+
for (int i = 0; i < numOfElements; i++) {
167163
tensorFA.set(i, tensorHFA.get(i).getFloat32());
168164
}
169165
yield new FP32TornadoTensor(tensorFA);
@@ -172,13 +168,15 @@ public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
172168
Q8_0TornadoTensor tensorQ8_0 = Q8_0TornadoTensor.create(entry);
173169
int numOfElements = tensorQ8_0.getSize();
174170
FloatArray tensorFA = new FloatArray(numOfElements);
175-
for(int i = 0; i < numOfElements; i++) {
171+
for (int i = 0; i < numOfElements; i++) {
176172
tensorFA.set(i, tensorQ8_0.getFloat(i));
177173
}
178174
yield new FP32TornadoTensor(tensorFA);
179175

180176
}
181-
default -> { throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type()); }
177+
default -> {
178+
throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type());
179+
}
182180
};
183181
}
184182

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
114114

115115
@Override
116116
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
117-
GGMLTensorEntry outputWeight) {
117+
GGMLTensorEntry outputWeight) {
118118
GGMLType ggmlType = outputWeight.ggmlType();
119-
119+
120120
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
121121
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
122122
}

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weig
8686

8787
@Override
8888
protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
89-
GGMLTensorEntry outputWeight) {
89+
GGMLTensorEntry outputWeight) {
9090

9191
final int nl = config.numberOfLayers();
9292

@@ -114,7 +114,7 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
114114

115115
@Override
116116
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
117-
GGMLTensorEntry outputWeight) {
117+
GGMLTensorEntry outputWeight) {
118118
GGMLType ggmlType = outputWeight.ggmlType();
119119

120120
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {

src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weig
8888

8989
@Override
9090
protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
91-
GGMLTensorEntry outputWeight) {
91+
GGMLTensorEntry outputWeight) {
9292
float[] ropeFreqsReal = ropeFreqs.first();
9393
float[] ropeFreqsImag = ropeFreqs.second();
9494

src/main/java/org/beehive/gpullama3/tensor/GGUF.java

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,11 @@ public static GGUF loadGGUFMetadata(Path modelPath) throws IOException {
9696
* Loads tensor data from a given file channel based on the tensor metadata information.
9797
* The mapping is read-only and creates standard memory segments for each tensor.
9898
*
99-
* @param fileChannel the channel from which tensor storage is read
100-
* @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
101-
* @param tensorInfos metadata describing all GGUF tensors
102-
*
99+
* @param fileChannel the channel from which tensor storage is read
100+
* @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
101+
* @param tensorInfos metadata describing all GGUF tensors
103102
* @return a map from tensor name to {@link GGMLTensorEntry} containing
104-
* standard memory segments for each tensor
105-
*
103+
* standard memory segments for each tensor
106104
* @throws IOException if memory mapping fails or the channel cannot be read
107105
*/
108106
public static Map<String, GGMLTensorEntry> loadTensorsStandard(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
@@ -152,14 +150,11 @@ public static Map<String, GGMLTensorEntry> loadTensorsStandard(FileChannel fileC
152150
* before the actual tensor position, providing a writable header region
153151
* without modifying the underlying GGUF file.</p>
154152
*
155-
*
156-
* @param fileChannel the channel from which tensor storage is read
157-
* @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
158-
* @param tensorInfos metadata describing all GGUF tensors
159-
*
153+
* @param fileChannel the channel from which tensor storage is read
154+
* @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
155+
* @param tensorInfos metadata describing all GGUF tensors
160156
* @return a map from tensor name to {@link GGMLTensorEntry} containing
161-
* TornadoVM-compatible memory segments for each tensor
162-
*
157+
* TornadoVM-compatible memory segments for each tensor
163158
* @throws IOException if memory mapping fails or the channel cannot be read
164159
*/
165160
public static Map<String, GGMLTensorEntry> loadTensorsTornado(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {

src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public abstract class TornadoTensor {
1515

1616
/**
1717
* Get as FloatArray (for F32 tensors).
18+
*
1819
* @throws UnsupportedOperationException if not F32
1920
*/
2021
public FloatArray asFloatArray() {
@@ -23,6 +24,7 @@ public FloatArray asFloatArray() {
2324

2425
/**
2526
* Get as HalfFloatArray (for F16 tensors).
27+
*
2628
* @throws UnsupportedOperationException if not F16
2729
*/
2830
public HalfFloatArray asHalfFloatArray() {
@@ -31,6 +33,7 @@ public HalfFloatArray asHalfFloatArray() {
3133

3234
/**
3335
* Get quantized scales (for Q8_0 tensors).
36+
*
3437
* @throws UnsupportedOperationException if not quantized
3538
*/
3639
public HalfFloatArray getScales() {
@@ -39,6 +42,7 @@ public HalfFloatArray getScales() {
3942

4043
/**
4144
* Get quantized values (for Q8_0 tensors).
45+
*
4246
* @throws UnsupportedOperationException if not quantized
4347
*/
4448
public Int8Array getQuants() {

0 commit comments

Comments
 (0)