Skip to content

Commit 23635c5

Browse files
Minor fixes
1 parent 3d5c340 commit 23635c5

File tree

2 files changed

+125
-2
lines changed

2 files changed

+125
-2
lines changed

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
import uk.ac.manchester.tornado.api.types.arrays.*;
1616

1717
import java.io.IOException;
18+
import java.lang.foreign.MemorySegment;
19+
import java.lang.foreign.ValueLayout;
1820
import java.nio.ByteOrder;
1921
import java.nio.FloatBuffer;
2022
import java.nio.channels.FileChannel;
2123
import java.nio.file.Path;
2224
import java.nio.file.StandardOpenOption;
2325
import java.util.Map;
26+
import java.util.Set;
2427
import java.util.function.IntFunction;
28+
import java.util.stream.Collectors;
2529

2630
public abstract class ModelLoader {
2731

@@ -88,7 +92,120 @@ public static Model loadModel(Options options) throws IOException {
8892
// detect model type
8993
ModelType modelType = detectModelType(gguf.getMetadata());
9094
// model type-specific load
91-
return modelType.loadModel(gguf.getFileChannel(), gguf, contextLength, loadWeights, useTornadovm);
95+
return modelType.loadModel(gguf.getFileChannel(), gguf, contextLength, useTornadovm);
96+
}
97+
98+
private static void compareTensorEntries(Map<String, GGMLTensorEntry> tensorEntries1, Map<String, GGMLTensorEntry> tensorEntries2) {
99+
System.out.println("[COMPARISON] Starting tensor entries comparison...");
100+
101+
// Check if both maps have the same keys
102+
Set<String> keys1 = tensorEntries1.keySet();
103+
Set<String> keys2 = tensorEntries2.keySet();
104+
105+
if (!keys1.equals(keys2)) {
106+
System.err.println("[ERROR] Tensor entry key sets don't match!");
107+
System.err.println("Keys in tensorEntries1 only: " +
108+
keys1.stream().filter(k -> !keys2.contains(k)).collect(Collectors.toSet()));
109+
System.err.println("Keys in tensorEntries2 only: " +
110+
keys2.stream().filter(k -> !keys1.contains(k)).collect(Collectors.toSet()));
111+
return;
112+
}
113+
114+
int totalTensors = keys1.size();
115+
int matchingTensors = 0;
116+
int errors = 0;
117+
118+
for (String tensorName : keys1) {
119+
GGMLTensorEntry entry1 = tensorEntries1.get(tensorName);
120+
GGMLTensorEntry entry2 = tensorEntries2.get(tensorName);
121+
122+
if (entry1 == null || entry2 == null) {
123+
System.err.println("[ERROR] Missing tensor entry for: " + tensorName);
124+
errors++;
125+
continue;
126+
}
127+
128+
try {
129+
boolean isMatch = compareSingleTensor(tensorName, entry1, entry2);
130+
if (isMatch) {
131+
matchingTensors++;
132+
System.out.println("[OK] " + tensorName + " - tensors match");
133+
} else {
134+
errors++;
135+
System.err.println("[MISMATCH] " + tensorName + " - tensors don't match");
136+
}
137+
} catch (Exception e) {
138+
errors++;
139+
System.err.println("[ERROR] Exception comparing " + tensorName + ": " + e.getMessage());
140+
}
141+
}
142+
143+
System.out.println("\n[COMPARISON SUMMARY]");
144+
System.out.println("Total tensors: " + totalTensors);
145+
System.out.println("Matching tensors: " + matchingTensors);
146+
System.out.println("Errors/Mismatches: " + errors);
147+
System.out.println("Success rate: " + String.format("%.1f%%", (matchingTensors * 100.0) / totalTensors));
148+
}
149+
150+
private static boolean compareSingleTensor(String tensorName, GGMLTensorEntry entry1, GGMLTensorEntry entry2) {
151+
// Get memory segments
152+
MemorySegment segment1 = entry1.memorySegment();
153+
MemorySegment segment2 = entry2.memorySegment();
154+
155+
// Special case: token_embd.weight and rope_freqs.weight should be identical
156+
boolean isSpecialCase = tensorName.equals("token_embd.weight") || tensorName.equals("rope_freqs.weight");
157+
158+
if (isSpecialCase) {
159+
// For these tensors, the segments should be identical
160+
if (segment1.byteSize() != segment2.byteSize()) {
161+
System.err.println(" Size mismatch for " + tensorName + ": " +
162+
segment1.byteSize() + " vs " + segment2.byteSize());
163+
return false;
164+
}
165+
166+
// Compare byte by byte
167+
for (long i = 0; i < segment1.byteSize(); i++) {
168+
byte b1 = segment1.get(ValueLayout.JAVA_BYTE, i);
169+
byte b2 = segment2.get(ValueLayout.JAVA_BYTE, i);
170+
if (b1 != b2) {
171+
System.err.println(" Byte mismatch at offset " + i + " for " + tensorName +
172+
": " + String.format("0x%02X", b1) + " vs " + String.format("0x%02X", b2));
173+
return false;
174+
}
175+
}
176+
return true;
177+
}
178+
179+
// For regular tensors, segment2 should have 16-byte header + segment1 data
180+
long expectedSize2 = segment1.byteSize() + 16;
181+
if (segment2.byteSize() != expectedSize2) {
182+
System.err.println(" Size mismatch for " + tensorName + ": expected " +
183+
expectedSize2 + " (16 + " + segment1.byteSize() + "), got " + segment2.byteSize());
184+
return false;
185+
}
186+
187+
// Check that first 16 bytes of segment2 are zeros (header)
188+
for (long i = 0; i < 16; i++) {
189+
byte headerByte = segment2.get(ValueLayout.JAVA_BYTE, i);
190+
if (headerByte != 0) {
191+
System.err.println(" Non-zero header byte at offset " + i + " for " + tensorName +
192+
": " + String.format("0x%02X", headerByte));
193+
return false;
194+
}
195+
}
196+
197+
// Compare the actual tensor data (starting at offset 16 in segment2)
198+
for (long i = 0; i < segment1.byteSize(); i++) {
199+
byte b1 = segment1.get(ValueLayout.JAVA_BYTE, i);
200+
byte b2 = segment2.get(ValueLayout.JAVA_BYTE, i + 16); // +16 to skip header
201+
if (b1 != b2) {
202+
System.err.println(" Data mismatch at offset " + i + " for " + tensorName +
203+
": " + String.format("0x%02X", b1) + " vs " + String.format("0x%02X", b2));
204+
return false;
205+
}
206+
}
207+
208+
return true;
92209
}
93210

94211
/**

src/main/java/org/beehive/gpullama3/tensor/GGUF.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import org.beehive.gpullama3.tensor.standard.FloatTensor;
44
import org.beehive.gpullama3.auxiliary.Pair;
5+
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
56

67
import java.io.FileNotFoundException;
78
import java.io.IOException;
89
import java.lang.foreign.Arena;
910
import java.lang.foreign.MemorySegment;
11+
import java.lang.foreign.ValueLayout;
1012
import java.nio.ByteBuffer;
1113
import java.nio.ByteOrder;
1214
import java.nio.channels.FileChannel;
@@ -62,7 +64,7 @@ public static GGUF loadGGUFMetadata(Path modelPath) throws IOException {
6264
gguf.readHeader(fileChannel); // gguf_header_t header;
6365
// Tensor infos, which can be used to locate the tensor data.
6466
// gguf_tensor_info_t tensor_infos[header.tensor_count];
65-
this.tensorInfos = HashMap.newHashMap(gguf.tensorCount);
67+
gguf.tensorInfos = HashMap.newHashMap(gguf.tensorCount);
6668
for (int i = 0; i < gguf.tensorCount; ++i) {
6769
GGUF.GGUFTensorInfo ti = gguf.readTensorInfo(fileChannel);
6870
assert !gguf.tensorInfos.containsKey(ti.name);
@@ -204,6 +206,10 @@ public Map<String, Object> getMetadata() {
204206
return metadata;
205207
}
206208

209+
public FileChannel getFileChannel() {
210+
return fileChannel;
211+
}
212+
207213
private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
208214
int ggmlTypeId = readInt(fileChannel); // ggml_type type;
209215
return GGMLType.fromId(ggmlTypeId);

0 commit comments

Comments
 (0)