54 changes: 54 additions & 0 deletions examples/llava/README.md
@@ -103,6 +103,59 @@ python ./examples/convert-legacy-llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)

## Phi-3-Vision-128K-Instruct gguf conversion
1) Set up a working directory for PHI3V and PHI3-instruct and clone both into it. (It's easiest to cd into your local hf cache and copy the models from there.) See the reviewer note after the block below regarding directory names.

```console
mkdir phi3-fun
cd phi3-fun

mkdir phi3-base
git clone https://huggingface.co/microsoft/Phi-3-mini-128k-instruct

mkdir phi3-vision
git clone https://huggingface.co/microsoft/Phi-3-vision-128k-instruct

```

Contributor: The directories won't match up at this point, since "git clone" creates its own subdirectories; renaming them would be better, and the two mkdir calls could be removed.
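
A minimal sketch of the reviewer's suggestion, assuming the later steps keep the `phi3-base` and `phi3-vision` directory names used throughout this guide:

```console
mkdir phi3-fun
cd phi3-fun

# clone directly into the directory names the later steps expect
git clone https://huggingface.co/microsoft/Phi-3-mini-128k-instruct phi3-base
git clone https://huggingface.co/microsoft/Phi-3-vision-128k-instruct phi3-vision
```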

2) Use `llava-surgery-v2.py` to extract the CLIP part from PHI3V:
```console
python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/
```
- you will find a llava.projector and a llava.clip file in your model directory
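
A quick sanity check that both files were produced (paths assume the layout from step 1):

```console
ls phi3-fun/phi3-vision/llava.projector phi3-fun/phi3-vision/llava.clip
```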

3) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin, and add a fitting vit configuration to the directory:
```console
# under the phi3-fun/phi3-vision dir
mkdir vit
cp llava.clip vit/pytorch_model.bin
cp llava.projector vit/
curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
```
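
After these commands the `vit` subdirectory should contain the three files the next step expects; this is just a check, not an extra step:

```console
ls phi3-fun/phi3-vision/vit
# expected: config.json  llava.projector  pytorch_model.bin
```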

4) Create the visual gguf model:
```console
python examples/llava/convert-image-encoder-to-gguf.py -m phi3-fun/phi3-vision/vit --llava-projector phi3-fun/phi3-vision/vit/llava.projector --output-dir phi3-fun/phi3-vision/vit --clip-model-is-vision
```

Contributor (@cmp-nct, Jun 14, 2024): Consider passing `--projector-type mlp_phi`. I don't think changing the config.json should still be required when this is specified; that would remove one necessary manual step from the list.
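
The same call with the reviewer's suggested flag would look roughly like this (a sketch; the `--projector-type mlp_phi` flag is taken from the reviewer's comment and whether the conversion script accepts it is an assumption here):

```console
python examples/llava/convert-image-encoder-to-gguf.py -m phi3-fun/phi3-vision/vit \
    --llava-projector phi3-fun/phi3-vision/vit/llava.projector \
    --output-dir phi3-fun/phi3-vision/vit \
    --clip-model-is-vision --projector-type mlp_phi
```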

5) Extract the language-modelling part of PHI3V (everything except CLIP) and assign the weights to a normal PHI3 model:

```console
python examples/llava/phi3-weight-transfer.py --phi3-instruct-base-path phi3-fun/phi3-base --phi3v-base-path phi3-fun/phi3-vision
```
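
If the transfer succeeds, the combined weights are written next to the original shards in the base-model directory (the filename comes from phi3-weight-transfer.py below):

```console
ls phi3-fun/phi3-base/*.safetensors
# phi3-instruct-vision-weight-transfer.safetensors should now be listed
```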

6) Convert this to a normal gguf (first delete the old safetensors from this directory; a sketch of the deletion follows the block below):
```console
python convert-hf-to-gguf.py phi3-fun/phi3-base
```
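
For the deletion mentioned above (run it before the convert command), something like the following should work. This is a sketch that assumes the stock Hugging Face shard naming, so double-check the file names before removing anything:

```console
# remove only the original checkpoint shards; keep the transferred weights and the index json
rm phi3-fun/phi3-base/model-0000*-of-0000*.safetensors
```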

7) Invoke (recompile llama.cpp first):
```console
./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT"
```

Contributor: Templating should be recommended. The following should be correct for phi3v: `<|user|>\n<image>\nPROMPT<|end|>\n<|assistant|>\n`
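
Putting the reviewer's template suggestion together with the command above (a sketch; how the \n escapes are interpreted depends on your shell and on llava-cli's prompt handling):

```console
./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf \
    --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf \
    --image IMAGE -c 4096 --temp 0.1 \
    -p "<|user|>\n<image>\nPROMPT<|end|>\n<|assistant|>\n"
```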

## llava-cli templating and llava-1.6 prompting

llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
@@ -137,3 +190,4 @@ Alternatively just pay notice to how many "tokens" have been used for your promp
- [x] Support non-CPU backend for the image encoding part.
- [ ] Support different sampling methods.
- [ ] Support more model variants.

41 changes: 34 additions & 7 deletions examples/llava/llava-surgery-v2.py
@@ -38,7 +38,9 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
# file_type = 'pytorch'
model_path = os.path.dirname(checkpoint_path)
print(f"Searching for vision tower tensors in {checkpoint_path}")
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_processor.vision_model") or \
(k.startswith("model.vision_tower")) or \
(k.startswith("vit.")))]

if len(clip_tensors) > 0:
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -83,10 +85,13 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
return newline_checkpoint_path, projector_checkpoint_path

def newline_criteria(checkpoint):
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
return any(k.startswith("model.vision_embed_tokens.sub_GN") or \
k.startswith("model.image_newline") for k in checkpoint.keys())

def proj_criteria(checkpoint):
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
return any(k.startswith("model.vision_embed_tokens.img_projection") or \
k.startswith("vision_proj.") or \
k.startswith("model.mm_projector") for k in checkpoint.keys())


# Command-line interface setup
@@ -121,14 +126,16 @@ def proj_criteria(checkpoint):
if newline_checkpoint_path is not None:
print(f"Taking newline from {newline_checkpoint_path}")
first_checkpoint, file_type = load_model(newline_checkpoint_path)
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.vision_embed_tokens.sub_GN") or k.startswith("model.image_newline")]

# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
mm_tensors = [k for k, v in last_checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_projection")) or \
(k.startswith("vision_proj.")) or \
(k.startswith("model.mm_projector"))]

if len(mm_tensors) == 0:
if last_checkpoint is not None:
@@ -144,8 +151,28 @@ def proj_criteria(checkpoint):
projector = {}
for name in mm_tensors:
projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
projector[name] = first_checkpoint[name].float()

def rename_keys(d, prefix):
new_dict = {}
for key, value in d.items():
parts = key.split('.')
new_key = f"{prefix}.{parts[-2]}.{parts[-1]}"
new_dict[new_key] = value
return new_dict

if list(projector.keys())[0].startswith("mm") is False:

print("-------------------------------")
print("PHI3V clip implicit conversion")
print("-------------------------------")

projector = rename_keys(projector, "mm")

for name in first_mm_tensors:
projector["model.image_newline"] = first_checkpoint[name].float()[0, 0, 0, :]

print("Updated projector keys to match LLAVA clip schema")
print(projector)

if len(projector) > 0:
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
79 changes: 79 additions & 0 deletions examples/llava/phi3-weight-transfer.py
Contributor: Consider putting this entire logic into llava-surgery-v2.py.

@@ -0,0 +1,79 @@
import argparse
import json
import os

import torch
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM


def main(args):

# https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another

phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,\
device_map="auto",\
trust_remote_code=True,\
torch_dtype=torch.float16,\
_attn_implementation='eager')

print("PHI3 VISION LOADED IN MEMORY")

phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,\
device_map="auto",\
trust_remote_code=True,\
torch_dtype=torch.float16,\
_attn_implementation='eager')

print("PHI3 BASE LOADED IN MEMORY")

phi3_vision_layers = dict(phi3_vision.named_parameters())
phi3_base_layers = dict(phi3_base.named_parameters())

parts = list(set(phi3_vision_layers.keys()) & set(phi3_base_layers.keys()))

print("----------------------------------------------------")
print("before transfer")
print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \
dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
print("----------------------------------------------------")

for part in parts:
phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data)
# target # source

print("----------------------------------------------------")
print("after transfer")
print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \
dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
print("----------------------------------------------------")

# save updated model weights
outfile = "phi3-instruct-vision-weight-transfer.safetensors"
outpath = os.path.join(args.phi3_instruct_base_path, outfile)
save_file(phi3_base_layers, outpath)
print(f"updates .safetensors saved to {outpath}")

# update safetensors index config
weight_index_path = os.path.join(args.phi3_instruct_base_path, "model.safetensors.index.json")

with open(weight_index_path, "r") as f:
index_data = json.load(f)

for k,v in index_data["weight_map"].items():
if v != "phi3-instruct-vision-weight-transfer.safetensors":
index_data["weight_map"][k] = outfile

with open(weight_index_path, "w") as f:
json.dump(index_data, f)

print(f"hf saftensor mapping updated!")

if __name__ == '__main__':

parser = argparse.ArgumentParser(description="script to copy weights from PHI3V language model to PHI3-instruct")

parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct", help="model path or model card for PHI3-instruct")
parser.add_argument("--phi3v-base-path", type=str, default="microsoft/Phi-3-vision-128k-instruct", help="model path or model card for PHI3V")

main(parser.parse_args())
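
For reference, the invocation from step 5 of the README above, split over multiple lines (paths assume the phi3-fun layout):

```console
python examples/llava/phi3-weight-transfer.py \
    --phi3-instruct-base-path phi3-fun/phi3-base \
    --phi3v-base-path phi3-fun/phi3-vision
```
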
8 changes: 4 additions & 4 deletions ggml-metal.m
@@ -779,7 +779,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) {
if (op->src[1]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
@@ -1523,10 +1523,10 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_MUL_MAT:
{
GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne00 == ne10);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
// GGML_ASSERT(ne12 % ne02 == 0);
// GGML_ASSERT(ne13 % ne03 == 0);

const uint r2 = ne12/ne02;
const uint r3 = ne13/ne03;
4 changes: 2 additions & 2 deletions ggml.c
@@ -5290,8 +5290,8 @@ struct ggml_tensor * ggml_mul_mat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_can_mul_mat(a, b));
GGML_ASSERT(!ggml_is_transposed(a));
// GGML_ASSERT(ggml_can_mul_mat(a, b));
// GGML_ASSERT(!ggml_is_transposed(a));
Member: These asserts should not be removed. If you hit them, there is most likely something wrong with the input data.

Author: I think I found the issue: tensor b should be [1024, 576] -> [4096, 576].
[screenshot, 2024-06-03 12:02 PM]

For LLAVA these dimensions are right:
[screenshot, 2024-06-03 12:04 PM]

But for phi3v we need the mm_projector weight tensor to be [4096, 576]:
[screenshot, 2024-06-03 12:05 PM]

Any idea on how to fix this?


bool is_node = false;
