diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml deleted file mode 100644 index a3a0b0d6638..00000000000 --- a/.github/workflows/build-riscv-native.yml +++ /dev/null @@ -1,120 +0,0 @@ -name: Build on RISCV Linux Machine by Cloud-V -on: - pull_request: - workflow_dispatch: - workflow_call: - -jobs: - debian-13-riscv64-native: # Bianbu 2.2 - runs-on: [self-hosted, RISCV64] - - steps: - - name: Install prerequisites - run: | - sudo apt-get update || true - sudo apt-get install -y libatomic1 - - uses: actions/checkout@v4 - - name: Setup Riscv - run: | - sudo apt-get update || true - sudo apt-get install -y --no-install-recommends \ - build-essential \ - gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - ccache \ - cmake - - - name: Setup ccache - run: | - mkdir -p $HOME/.ccache - ccache -M 5G -d $HOME/.ccache - export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - echo "$GITHUB_WORKSPACE" - echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - - name: Build - run: | - cmake -B build \ - -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) - - # debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 - # runs-on: [self-hosted, RISCV64] - - # steps: - # - name: Install prerequisites - # run: | - # sudo apt-get update || true - # sudo apt-get install -y libatomic1 - # - uses: actions/checkout@v4 - # - name: Setup Riscv - # run: | - # sudo apt-get update || true - # sudo apt-get install -y --no-install-recommends \ - # build-essential \ - # gcc-14-riscv64-linux-gnu \ - # g++-14-riscv64-linux-gnu \ - # ccache \ - # cmake - # sudo apt-get upgrade binutils -y - - # - name: Setup ccache - # run: | - # mkdir -p $HOME/.ccache - # ccache -M 5G -d $HOME/.ccache - # export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - # export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - # echo "$GITHUB_WORKSPACE" - # echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - # echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - # echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - # echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - # - name: Build - # run: | - # cmake -B build \ - # -DLLAMA_CURL=OFF \ - # -DCMAKE_BUILD_TYPE=Release \ - # -DGGML_OPENMP=OFF \ - # -DLLAMA_BUILD_EXAMPLES=ON \ - # -DLLAMA_BUILD_TOOLS=ON \ - # -DLLAMA_BUILD_TESTS=OFF \ - # -DCMAKE_SYSTEM_NAME=Linux \ - # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - # -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - # -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ - # -DGGML_RVV=ON \ - # -DGGML_RV_ZFH=ON \ - # -DGGML_RV_ZICBOP=ON \ - # -DGGML_CPU_RISCV64_SPACEMIT=ON \ - # -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 - - # cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index eee42759fc9..49e836d9b20 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -547,6 +547,46 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 3600 + ubuntu-24-wasm-webgpu: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ubuntu-latest-wasm-webgpu + evict-old-files: 1d + + - name: Install Emscripten + run: | + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + + - name: Fetch emdawnwebgpu + run: | + DAWN_TAG="v20251027.212519" + EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip" + echo "Downloading ${EMDAWN_PKG}" + curl -L -o emdawn.zip \ + "https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}" + unzip emdawn.zip + + - name: Build WASM WebGPU + run: | + source emsdk/emsdk_env.sh + emcmake cmake -B build-wasm \ + -DGGML_WEBGPU=ON \ + -DLLAMA_CURL=OFF \ + -DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg + + cmake --build build-wasm --target test-backend-ops -j $(nproc) + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.1.2 @@ -1642,6 +1682,337 @@ jobs: run: | GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + ubuntu-cpu-cmake-riscv64-native: + runs-on: RISCV64 + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Check environment + run: | + uname -a + gcc --version + g++ --version + ldd --version + cmake --version + rustc --version + + - name: Setup ccache + run: | + # Set unique cache directory for this job + export CCACHE_DIR="$HOME/.ccache/cpu-cmake-rv64-native" + mkdir -p "$CCACHE_DIR" + + # Configure ccache for optimal performance + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + + # Enable more aggressive caching + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + # Export for subsequent steps + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DGGML_RPC=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L 'main|curl' --verbose --timeout 900 + + - name: Test llama2c conversion + id: llama2c_test + run: | + cd build + echo "Fetch tokenizer" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/tok512.bin + echo "Fetch llama2c model" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin + ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf + ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + + ubuntu-cmake-sanitizer-riscv64-native: + runs-on: RISCV64 + + continue-on-error: true + + strategy: + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + build_type: [Debug] + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + # Unique cache directory per matrix combination + export CCACHE_DIR="$HOME/.ccache/sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}" + mkdir -p "$CCACHE_DIR" + + # Configure ccache + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + # Export for subsequent steps + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + if: ${{ matrix.sanitizer != 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=ON \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Build (no OpenMP) + id: cmake_build_no_openmp + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + + ubuntu-llguidance-riscv64-native: + runs-on: RISCV64 + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + export CCACHE_DIR="$HOME/.ccache/llguidance-riscv64" + mkdir -p "$CCACHE_DIR" + + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_LLGUIDANCE=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + + ubuntu-cmake-rpc-riscv64-native: + runs-on: RISCV64 + + continue-on-error: true + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + export CCACHE_DIR="$HOME/.ccache/rpc-riscv64" + mkdir -p "$CCACHE_DIR" + + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DGGML_RPC=ON + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose + ggml-ci-arm64-graviton4-kleidiai: runs-on: ah-ubuntu_22_04-c8g_8x diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a0a13f38400..da1363a7982 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -728,58 +728,6 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz - openEuler-cann: - strategy: - matrix: - arch: [x86, aarch64] - chip_type: ['910b', '310p'] - build: ['Release'] - runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} - container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }} - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Dependencies - run: | - yum update -y - yum install -y git gcc gcc-c++ make cmake libcurl-devel - git config --global --add safe.directory "$GITHUB_WORKSPACE" - - - name: Build - run: | - export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} - - cmake -S . -B build \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_CANN=on \ - -DSOC_TYPE=ascend${{ matrix.chip_type }} - cmake --build build -j $(nproc) - - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name - - - name: Pack artifacts - run: | - cp LICENSE ./build/bin/ - zip -y -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/* - tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz -C ./build/bin . - - - name: Upload artifacts (zip) - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip - name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip - - - name: Upload artifacts (tar) - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz - name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz - release: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} @@ -801,7 +749,6 @@ jobs: - macOS-arm64 - macOS-x64 - ios-xcode-build - - openEuler-cann steps: - name: Clone @@ -869,6 +816,12 @@ jobs: > [!WARNING] > **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts. +
+ + ${{ github.event.head_commit.message }} + +
+ **macOS/iOS:** - [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz) - [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz) @@ -887,18 +840,6 @@ jobs: - [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip) - [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip) - **openEuler:** - - [openEuler x86 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-x86.tar.gz) - - [openEuler x86 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-x86.tar.gz) - - [openEuler aarch64 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-aarch64.tar.gz) - - [openEuler aarch64 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-aarch64.tar.gz) - -
- - ${{ github.event.head_commit.message }} - -
- - name: Upload release id: upload_release uses: actions/github-script@v3 diff --git a/.gitignore b/.gitignore index 8575a141c40..428f0841100 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,5 @@ poetry.toml # IDE /*.code-workspace /.windsurf/ +# emscripten +a.out.* diff --git a/CMakeLists.txt b/CMakeLists.txt index 3278c4a72c1..c231ec0e3fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,10 +33,24 @@ endif() option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) +option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON) + if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) - option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) + # Use 64-bit memory to support backend_get_memory queries + # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using + if (LLAMA_WASM_MEM64) + add_compile_options("-sMEMORY64=1") + add_link_options("-sMEMORY64=1") + endif() + add_link_options("-sALLOW_MEMORY_GROWTH=1") + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) + option(LLAMA_BUILD_HTML "llama: build HTML file" ON) + if (LLAMA_BUILD_HTML) + set(CMAKE_EXECUTABLE_SUFFIX ".html") + endif() else() if (MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) @@ -58,6 +72,12 @@ if (MSVC) add_compile_options("$<$:/bigobj>") endif() +if (LLAMA_STANDALONE) + # enable parallel builds for msbuild + list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true) + list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true) +endif() + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") set(LLAMA_TOOLS_INSTALL_DEFAULT OFF) else() @@ -179,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() -if (MINGW) - # Target Windows 8 for PrefetchVirtualMemory - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # build the library # diff --git a/CODEOWNERS b/CODEOWNERS index 6ef6c0489f1..450191b7343 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,16 +7,19 @@ /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov -/common/arg.* @ggerganov @ericcurtin +/common/arg.* @ggerganov /common/base64.hpp.* @ggerganov /common/build-info.* @ggerganov +/common/chat-peg-parser.* @aldehir /common/common.* @ggerganov /common/console.* @ggerganov /common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov +/common/peg-parser.* @aldehir /common/sampling.* @ggerganov /common/speculative.* @ggerganov +/common/unicode.* @aldehir /convert_*.py @CISC /examples/batched.swift/ @ggerganov /examples/batched/ @ggerganov @@ -87,8 +90,7 @@ /tools/perplexity/ @ggerganov /tools/quantize/ @ggerganov /tools/rpc/ @rgerganov -/tools/run/ @ericcurtin -/tools/server/* @ngxson @ggerganov @ericcurtin # no subdir +/tools/server/* @ngxson @ggerganov # no subdir /tools/server/webui/ @allozaur /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov diff --git a/ci/run.sh b/ci/run.sh index 1dd65adeaac..83b2603e821 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -45,7 +45,7 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=${LLAMA_FATAL_WARNINGS:-ON} -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index bb168e8358a..377b26846b6 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -52,6 +52,8 @@ add_library(${TARGET} STATIC chat-parser.h chat-parser-xml-toolcall.h chat-parser-xml-toolcall.cpp + chat-peg-parser.cpp + chat-peg-parser.h chat.cpp chat.h common.cpp @@ -69,12 +71,16 @@ add_library(${TARGET} STATIC log.h ngram-cache.cpp ngram-cache.h + peg-parser.cpp + peg-parser.h regex-partial.cpp regex-partial.h sampling.cpp sampling.h speculative.cpp speculative.h + unicode.cpp + unicode.h ) if (BUILD_SHARED_LIBS) diff --git a/common/arg.cpp b/common/arg.cpp index 52094e3f10a..cb2b4c603bc 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -30,6 +30,7 @@ #include // for hardware_concurrency #include +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -41,6 +42,8 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 using json = nlohmann::ordered_json; @@ -1226,7 +1229,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.warmup = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -2488,12 +2491,29 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to save slot kv cache (default: disabled)", [](common_params & params, const std::string & value) { params.slot_save_path = value; + if (!fs_is_directory(params.slot_save_path)) { + throw std::invalid_argument("not a directory: " + value); + } // if doesn't end with DIRECTORY_SEPARATOR, add it if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { params.slot_save_path += DIRECTORY_SEPARATOR; } } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--media-path"}, "PATH", + "directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)", + [](common_params & params, const std::string & value) { + params.media_path = value; + if (!fs_is_directory(params.media_path)) { + throw std::invalid_argument("not a directory: " + value); + } + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.media_path.empty() && params.media_path[params.media_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.media_path += DIRECTORY_SEPARATOR; + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--models-dir"}, "PATH", "directory containing models for the router server (default: disabled)", @@ -3208,5 +3228,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER})); + // Phi3-5 Vision num_crops + add_opt(common_arg( + {"--num-crops"}, "N", + string_format("number of crops for Phi-3-Vision image processing (default: loaded from model"), + [](common_params & params, int value) { + if (value < 0) { + throw std::runtime_error("num_crops must be positive"); + } + params.num_crops = value; + } + )); + return ctx_arg; } diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 301f439a6f9..fe3e80037f4 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -1,6 +1,8 @@ #include "chat-parser.h" +#include "chat-peg-parser.h" #include "common.h" #include "log.h" +#include "peg-parser.h" #include "regex-partial.h" #include @@ -1483,6 +1485,11 @@ static void common_chat_parse(common_chat_msg_parser & builder) { } common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE || + syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE || + syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { + return common_chat_peg_parse(syntax.parser, input, is_partial, syntax); + } common_chat_msg_parser builder(input, is_partial, syntax); try { common_chat_parse(builder); @@ -1500,3 +1507,36 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co } return msg; } + +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + if (parser.empty()) { + throw std::runtime_error("Failed to parse due to missing parser definition."); + } + + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str()); + + common_peg_parse_context ctx(input, is_partial); + auto result = parser.parse(ctx); + if (result.fail()) { + throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end)); + } + + common_chat_msg msg; + msg.role = "assistant"; + + if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) { + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + } else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { + auto mapper = common_chat_peg_constructed_mapper(msg); + mapper.from_ast(ctx.ast, result); + } else { + // Generic mapper + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + } + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp new file mode 100644 index 00000000000..74a7b6a46dc --- /dev/null +++ b/common/chat-peg-parser.cpp @@ -0,0 +1,114 @@ +#include "chat-peg-parser.h" + +#include + +using json = nlohmann::json; + +static std::string_view trim_trailing_space(std::string_view sv) { + while (!sv.empty() && std::isspace(static_cast(sv.back()))) { + sv.remove_suffix(1); + } + return sv; +} + +void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { + arena.visit(result, [this](const common_peg_ast_node & node) { + map(node); + }); +} + +void common_chat_peg_mapper::map(const common_peg_ast_node & node) { + bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; + bool is_content = node.tag == common_chat_peg_builder::CONTENT; + + if (is_reasoning) { + result.reasoning_content = std::string(trim_trailing_space(node.text)); + } + + if (is_content) { + result.content = std::string(trim_trailing_space(node.text)); + } +} + +void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) { + common_chat_peg_mapper::map(node); + + bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN; + bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME; + bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID; + bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS; + + if (is_tool_open) { + result.tool_calls.emplace_back(); + current_tool = &result.tool_calls.back(); + } + + if (is_tool_id && current_tool) { + current_tool->id = std::string(trim_trailing_space(node.text)); + } + + if (is_tool_name && current_tool) { + current_tool->name = std::string(trim_trailing_space(node.text)); + } + + if (is_tool_args && current_tool) { + current_tool->arguments = std::string(trim_trailing_space(node.text)); + } +} + +void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) { + common_chat_peg_mapper::map(node); + + bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN; + bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME; + bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE; + bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN; + bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE; + bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME; + bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE; + bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE; + + if (is_tool_open) { + result.tool_calls.emplace_back(); + current_tool = &result.tool_calls.back(); + arg_count = 0; + } + + if (is_tool_name) { + current_tool->name = std::string(node.text); + current_tool->arguments = "{"; + } + + if (is_arg_open) { + needs_closing_quote = false; + } + + if (is_arg_name && current_tool) { + if (arg_count > 0) { + current_tool->arguments += ","; + } + current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":"; + ++arg_count; + } + + if (is_arg_string && current_tool) { + // Serialize to JSON, but exclude the end quote + std::string dumped = json(node.text).dump(); + current_tool->arguments += dumped.substr(0, dumped.size() - 1); + needs_closing_quote = true; + } + + if (is_arg_close && current_tool) { + if (needs_closing_quote) { + current_tool->arguments += "\""; + } + } + + if (is_arg_json && current_tool) { + current_tool->arguments += std::string(trim_trailing_space(node.text)); + } + + if (is_tool_close && current_tool) { + current_tool->arguments += "}"; + } +} diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h new file mode 100644 index 00000000000..b84cbed2069 --- /dev/null +++ b/common/chat-peg-parser.h @@ -0,0 +1,105 @@ +#pragma once + +#include "chat.h" +#include "peg-parser.h" + +class common_chat_peg_builder : public common_peg_parser_builder { + public: + static constexpr const char * REASONING_BLOCK = "reasoning-block"; + static constexpr const char * REASONING = "reasoning"; + static constexpr const char * CONTENT = "content"; + + common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } + common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } + common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } +}; + +inline common_peg_arena build_chat_peg_parser(const std::function & fn) { + common_chat_peg_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} + +class common_chat_peg_mapper { + public: + common_chat_msg & result; + + common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {} + + virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); + virtual void map(const common_peg_ast_node & node); +}; + +class common_chat_peg_native_builder : public common_chat_peg_builder { + public: + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_ID = "tool-id"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARGS = "tool-args"; + + common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } + common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } + common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); } + common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); } +}; + +class common_chat_peg_native_mapper : public common_chat_peg_mapper { + common_chat_tool_call * current_tool; + + public: + common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + + void map(const common_peg_ast_node & node) override; +}; + +inline common_peg_arena build_chat_peg_native_parser(const std::function & fn) { + common_chat_peg_native_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} + +class common_chat_peg_constructed_builder : public common_chat_peg_builder { + public: + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARG = "tool-arg"; + static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; + static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; + static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; + static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; + static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value"; + + common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } + common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } + common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); } + common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); } + common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); } + common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); } + common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); } +}; + +class common_chat_peg_constructed_mapper : public common_chat_peg_mapper { + common_chat_tool_call * current_tool; + int arg_count = 0; + bool needs_closing_quote = false; + + public: + common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + + void map(const common_peg_ast_node & node) override; +}; + +inline common_peg_arena build_chat_peg_constructed_parser(const std::function & fn) { + common_chat_peg_constructed_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} diff --git a/common/chat.cpp b/common/chat.cpp index b4a0f985e2e..41a5bb42d51 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const return message; } -std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) { std::vector diffs; - if (previous_msg.reasoning_content != new_msg.reasoning_content) { + if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) { + diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3); + } else { + diffs.reserve(3); + } + + // TODO: these can become expensive for long messages - how to optimize? + if (msg_prv.reasoning_content != msg_new.reasoning_content) { auto & diff = diffs.emplace_back(); - diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); + diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content); } - if (previous_msg.content != new_msg.content) { + if (msg_prv.content != msg_new.content) { auto & diff = diffs.emplace_back(); - diff.content_delta = string_diff(previous_msg.content, new_msg.content); + diff.content_delta = string_diff(msg_prv.content, msg_new.content); } - if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { + if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) { throw std::runtime_error("Invalid diff: now finding less tool calls!"); } - if (!previous_msg.tool_calls.empty()) { - auto idx = previous_msg.tool_calls.size() - 1; - const auto & pref = previous_msg.tool_calls[idx]; - const auto & newf = new_msg.tool_calls[idx]; + if (!msg_prv.tool_calls.empty()) { + const auto idx = msg_prv.tool_calls.size() - 1; + const auto & pref = msg_prv.tool_calls[idx]; + const auto & newf = msg_new.tool_calls[idx]; if (pref.name != newf.name) { throw std::runtime_error("Invalid diff: tool call mismatch!"); } - auto args_diff = string_diff(pref.arguments, newf.arguments); + const auto args_diff = string_diff(pref.arguments, newf.arguments); if (!args_diff.empty() || pref.id != newf.id) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; @@ -118,11 +125,12 @@ std::vector common_chat_msg_diff::compute_diffs(const comm diff.tool_call_delta.arguments = args_diff; } } - for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) { + for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; - diff.tool_call_delta = new_msg.tool_calls[idx]; + diff.tool_call_delta = msg_new.tool_calls[idx]; } + return diffs; } @@ -163,7 +171,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin if (tool_choice == "required") { return COMMON_CHAT_TOOL_CHOICE_REQUIRED; } - throw std::runtime_error("Invalid tool_choice: " + tool_choice); + throw std::invalid_argument("Invalid tool_choice: " + tool_choice); } bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) { @@ -186,17 +194,17 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa try { if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump()); } for (const auto & message : messages) { if (!message.is_object()) { - throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump()); } common_chat_msg msg; if (!message.contains("role")) { - throw std::runtime_error("Missing 'role' in message: " + message.dump()); + throw std::invalid_argument("Missing 'role' in message: " + message.dump()); } msg.role = message.at("role"); @@ -209,11 +217,11 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } else if (content.is_array()) { for (const auto & part : content) { if (!part.contains("type")) { - throw std::runtime_error("Missing content part type: " + part.dump()); + throw std::invalid_argument("Missing content part type: " + part.dump()); } const auto & type = part.at("type"); if (type != "text") { - throw std::runtime_error("Unsupported content part type: " + type.dump()); + throw std::invalid_argument("Unsupported content part type: " + type.dump()); } common_chat_msg_content_part msg_part; msg_part.type = type; @@ -221,25 +229,25 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } if (has_tool_calls) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; if (!tool_call.contains("type")) { - throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call type: " + tool_call.dump()); } const auto & type = tool_call.at("type"); if (type != "function") { - throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump()); } if (!tool_call.contains("function")) { - throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call function: " + tool_call.dump()); } const auto & fc = tool_call.at("function"); if (!fc.contains("name")) { - throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call name: " + tool_call.dump()); } tc.name = fc.at("name"); tc.arguments = fc.at("arguments"); @@ -250,7 +258,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } } if (!has_content && !has_tool_calls) { - throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); } if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); @@ -353,18 +361,18 @@ std::vector common_chat_tools_parse_oaicompat(const json & too try { if (!tools.is_null()) { if (!tools.is_array()) { - throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump()); } for (const auto & tool : tools) { if (!tool.contains("type")) { - throw std::runtime_error("Missing tool type: " + tool.dump()); + throw std::invalid_argument("Missing tool type: " + tool.dump()); } const auto & type = tool.at("type"); if (!type.is_string() || type != "function") { - throw std::runtime_error("Unsupported tool type: " + tool.dump()); + throw std::invalid_argument("Unsupported tool type: " + tool.dump()); } if (!tool.contains("function")) { - throw std::runtime_error("Missing tool function: " + tool.dump()); + throw std::invalid_argument("Missing tool function: " + tool.dump()); } const auto & function = tool.at("function"); @@ -649,6 +657,9 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; + case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; + case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; + case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; default: throw std::runtime_error("Unknown chat format"); } diff --git a/common/chat.h b/common/chat.h index 754c411e237..6085510a402 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,6 +3,7 @@ #pragma once #include "common.h" +#include "peg-parser.h" #include #include #include @@ -76,7 +77,7 @@ struct common_chat_msg_diff { size_t tool_call_index = std::string::npos; common_chat_tool_call tool_call_delta; - static std::vector compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); + static std::vector compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new); bool operator==(const common_chat_msg_diff & other) const { return content_delta == other.content_delta @@ -124,6 +125,11 @@ enum common_chat_format { COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, + // These are intended to be parsed by the PEG parser + COMMON_CHAT_FORMAT_PEG_SIMPLE, + COMMON_CHAT_FORMAT_PEG_NATIVE, + COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -154,6 +160,7 @@ struct common_chat_params { std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; + std::string parser; }; struct common_chat_syntax { @@ -163,6 +170,7 @@ struct common_chat_syntax { bool reasoning_in_content = false; bool thinking_forced_open = false; bool parse_tool_calls = true; + common_peg_arena parser = {}; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid @@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format); const char* common_reasoning_format_name(common_reasoning_format format); common_reasoning_format common_reasoning_format_from_name(const std::string & format); common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); diff --git a/common/common.cpp b/common/common.cpp index 10001f54697..f07af1d8625 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) - || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == ':' || c == '*' // Illegal characters || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { return false; } + if (!allow_subdirs && (c == '/' || c == '\\')) { + // Subdirectories not allowed, reject path separators + return false; + } } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -859,6 +863,11 @@ bool fs_create_directory_with_parents(const std::string & path) { #endif // _WIN32 } +bool fs_is_directory(const std::string & path) { + std::filesystem::path dir(path); + return std::filesystem::exists(dir) && std::filesystem::is_directory(dir); +} + std::string fs_get_cache_directory() { std::string cache_directory = ""; auto ensure_trailing_slash = [](std::string p) { @@ -893,6 +902,8 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif diff --git a/common/common.h b/common/common.h index cdca5e26a23..a24a9fd2951 100644 --- a/common/common.h +++ b/common/common.h @@ -12,6 +12,10 @@ #include #include +#if defined(_WIN32) && !defined(_WIN32_WINNT) +#define _WIN32_WINNT 0x0A00 +#endif + #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' #else @@ -432,7 +436,7 @@ struct common_params { std::vector image; // path to image file(s) int image_min_tokens = -1; int image_max_tokens = -1; - + int32_t num_crops = -1; // For Phi3-5 Vision, number of max local crops // finetune struct lr_opt lr; enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW; @@ -485,6 +489,7 @@ struct common_params { bool log_json = false; std::string slot_save_path; + std::string media_path; // path to directory for loading media files float slot_prompt_similarity = 0.1f; @@ -635,8 +640,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat // Filesystem utils // -bool fs_validate_filename(const std::string & filename); +bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false); bool fs_create_directory_with_parents(const std::string & path); +bool fs_is_directory(const std::string & path); std::string fs_get_cache_directory(); std::string fs_get_cache_file(const std::string & filename); diff --git a/common/download.cpp b/common/download.cpp index e6ce7f11f54..ab68c53b43d 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -24,6 +24,7 @@ #include "http.h" #endif +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -35,6 +36,8 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 // isatty diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index c8421e1e826..c3b4e5d9dc7 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -974,7 +974,7 @@ class SchemaConverter { void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); + throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp new file mode 100644 index 00000000000..dec99e1820f --- /dev/null +++ b/common/peg-parser.cpp @@ -0,0 +1,1712 @@ +#include "common.h" +#include "peg-parser.h" +#include "json-schema-to-grammar.h" +#include "unicode.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +// Trick to catch missing branches +template +inline constexpr bool is_always_false_v = false; + +const char * common_peg_parse_result_type_name(common_peg_parse_result_type type) { + switch (type) { + case COMMON_PEG_PARSE_RESULT_FAIL: return "fail"; + case COMMON_PEG_PARSE_RESULT_SUCCESS: return "success"; + case COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT: return "need_more_input"; + default: return "unknown"; + } +} + +static bool is_hex_digit(const char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + +// Trie for matching multiple literals. +// This is used in common_peg_until_parser and to build a GBNF exclusion grammar +struct trie { + struct node { + size_t depth = 0; + std::map children; + bool is_word; + }; + + std::vector nodes; + + trie(const std::vector & words) { + create_node(); // root node + for (const auto & w : words) { + insert(w); + } + } + + enum match_result { NO_MATCH, PARTIAL_MATCH, COMPLETE_MATCH }; + + // Check if a delimiter starts at the given position + match_result check_at(std::string_view sv, size_t start_pos) const { + size_t current = 0; // Start at root + size_t pos = start_pos; + + while (pos < sv.size()) { + auto it = nodes[current].children.find(sv[pos]); + if (it == nodes[current].children.end()) { + // Can't continue matching + return match_result{match_result::NO_MATCH}; + } + + current = it->second; + pos++; + + // Check if we've matched a complete word + if (nodes[current].is_word) { + return match_result{match_result::COMPLETE_MATCH}; + } + } + + // Reached end of input while still in the trie (not at root) + if (current != 0) { + // We're in the middle of a potential match + return match_result{match_result::PARTIAL_MATCH}; + } + + // Reached end at root (no match) + return match_result{match_result::NO_MATCH}; + } + + struct prefix_and_next { + std::string prefix; + std::string next_chars; + }; + + std::vector collect_prefix_and_next() { + std::string prefix; + std::vector result; + collect_prefix_and_next(0, prefix, result); + return result; + } + + private: + void collect_prefix_and_next(size_t index, std::string & prefix, std::vector & out) { + if (!nodes[index].is_word) { + if (!nodes[index].children.empty()) { + std::string chars; + chars.reserve(nodes[index].children.size()); + for (const auto & p : nodes[index].children) { + chars.push_back(p.first); + } + out.emplace_back(prefix_and_next{prefix, chars}); + } + } + + for (const auto & p : nodes[index].children) { + unsigned char ch = p.first; + auto child = p.second; + prefix.push_back(ch); + collect_prefix_and_next(child, prefix, out); + prefix.pop_back(); + } + } + + size_t create_node() { + size_t index = nodes.size(); + nodes.emplace_back(); + return index; + } + + void insert(const std::string & word) { + size_t current = 0; + for (unsigned char ch : word) { + auto it = nodes[current].children.find(ch); + if (it == nodes[current].children.end()) { + size_t child = create_node(); + nodes[child].depth = nodes[current].depth + 1; + nodes[current].children[ch] = child; + current = child; + } else { + current = it->second; + } + } + nodes[current].is_word = true; + } +}; + +static std::pair parse_hex_escape(const std::string & str, size_t pos, int hex_count) { + if (pos + hex_count > str.length()) { + return {0, 0}; + } + + uint32_t value = 0; + for (int i = 0; i < hex_count; i++) { + char c = str[pos + i]; + if (!is_hex_digit(c)) { + return {0, 0}; + } + value <<= 4; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + return {value, static_cast(hex_count)}; +} + +static std::pair parse_char_class_char(const std::string & content, size_t pos) { + if (content[pos] == '\\' && pos + 1 < content.length()) { + switch (content[pos + 1]) { + case 'x': { + auto result = parse_hex_escape(content, pos + 2, 2); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'x' + return {static_cast('x'), 2}; + } + case 'u': { + auto result = parse_hex_escape(content, pos + 2, 4); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'u' + return {static_cast('u'), 2}; + } + case 'U': { + auto result = parse_hex_escape(content, pos + 2, 8); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'U' + return {static_cast('U'), 2}; + } + case 'n': return {'\n', 2}; + case 't': return {'\t', 2}; + case 'r': return {'\r', 2}; + case '\\': return {'\\', 2}; + case ']': return {']', 2}; + case '[': return {'[', 2}; + default: return {static_cast(content[pos + 1]), 2}; + } + } + + // Regular character - return as codepoint + return {static_cast(static_cast(content[pos])), 1}; +} + +static std::pair, bool> parse_char_classes(const std::string & classes) { + std::vector ranges; + bool negated = false; + + std::string content = classes; + if (content.front() == '[') { + content = content.substr(1); + } + + if (content.back() == ']') { + content.pop_back(); + } + + // Check for negation + if (!content.empty() && content.front() == '^') { + negated = true; + content = content.substr(1); + } + + size_t i = 0; + while (i < content.length()) { + auto [start, start_len] = parse_char_class_char(content, i); + i += start_len; + + if (i + 1 < content.length() && content[i] == '-') { + // Range detected + auto [end, end_len] = parse_char_class_char(content, i + 1); + ranges.push_back(common_peg_chars_parser::char_range{start, end}); + i += 1 + end_len; + } else { + ranges.push_back(common_peg_chars_parser::char_range{start, start}); + } + } + + return {ranges, negated}; +} + +void common_peg_ast_arena::visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const { + if (id == COMMON_PEG_INVALID_AST_ID) { + return; + } + const auto & node = get(id); + visitor(node); + for (const auto & child : node.children) { + visit(child, visitor); + } +} + +void common_peg_ast_arena::visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const { + for (const auto & node : result.nodes) { + visit(node, visitor); + } +} + +struct parser_executor; + +common_peg_parser_id common_peg_arena::add_parser(common_peg_parser_variant parser) { + common_peg_parser_id id = parsers_.size(); + parsers_.push_back(std::move(parser)); + return id; +} + +void common_peg_arena::add_rule(const std::string & name, common_peg_parser_id id) { + rules_[name] = id; +} + +common_peg_parser_id common_peg_arena::get_rule(const std::string & name) const { + auto it = rules_.find(name); + if (it == rules_.end()) { + throw std::runtime_error("Rule not found: " + name); + } + return it->second; +} + +struct parser_executor { + const common_peg_arena & arena; + common_peg_parse_context & ctx; + size_t start_pos; + + parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start) + : arena(arena), ctx(ctx), start_pos(start) {} + + common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); + } + + common_peg_parse_result operator()(const common_peg_start_parser & /* p */) const { + return common_peg_parse_result( + start_pos == 0 ? COMMON_PEG_PARSE_RESULT_SUCCESS : COMMON_PEG_PARSE_RESULT_FAIL, + start_pos + ); + } + + common_peg_parse_result operator()(const common_peg_end_parser & /* p */) const { + return common_peg_parse_result( + start_pos >= ctx.input.size() ? COMMON_PEG_PARSE_RESULT_SUCCESS : COMMON_PEG_PARSE_RESULT_FAIL, + start_pos + ); + } + + common_peg_parse_result operator()(const common_peg_literal_parser & p) { + auto pos = start_pos; + for (auto i = 0u; i < p.literal.size(); ++i) { + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + if (ctx.input[pos] != p.literal[i]) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + ++pos; + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_sequence_parser & p) { + auto pos = start_pos; + std::vector nodes; + + for (const auto & child_id : p.children) { + auto result = arena.parse(child_id, ctx, pos); + if (result.fail()) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end); + } + + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + if (result.need_more_input()) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); + } + + pos = result.end; + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); + } + + common_peg_parse_result operator()(const common_peg_choice_parser & p) { + auto pos = start_pos; + for (const auto & child_id : p.children) { + auto result = arena.parse(child_id, ctx, pos); + if (!result.fail()) { + return result; + } + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + common_peg_parse_result operator()(const common_peg_repetition_parser & p) { + auto pos = start_pos; + int match_count = 0; + std::vector nodes; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (p.max_count == -1 || match_count < p.max_count) { + if (pos >= ctx.input.size()) { + break; + } + + auto result = arena.parse(p.child, ctx, pos); + + if (result.success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + pos = result.end; + match_count++; + continue; + } + + if (result.need_more_input()) { + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); + } + + // Child failed - stop trying + break; + } + + // Check if we got enough matches + if (p.min_count > 0 && match_count < p.min_count) { + if (pos >= ctx.input.size() && ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes)); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); + } + + common_peg_parse_result operator()(const common_peg_and_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + // Pass result but don't consume input + return common_peg_parse_result(result.type, start_pos); + } + + common_peg_parse_result operator()(const common_peg_not_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + + if (result.success()) { + // Fail if the underlying parser matches + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + if (result.need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } + + // Child failed, so negation succeeds + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); + } + + common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const { + // Parse a single UTF-8 codepoint (not just a single byte) + auto result = parse_utf8_codepoint(ctx.input, start_pos); + + if (result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos); + } + if (result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, start_pos + result.bytes_consumed); + } + + common_peg_parse_result operator()(const common_peg_space_parser & /* p */) { + auto pos = start_pos; + while (pos < ctx.input.size()) { + auto c = static_cast(ctx.input[pos]); + if (std::isspace(c)) { + ++pos; + } else { + break; + } + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_chars_parser & p) const { + auto pos = start_pos; + int match_count = 0; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (p.max_count == -1 || match_count < p.max_count) { + auto result = parse_utf8_codepoint(ctx.input, pos); + + if (result.status == utf8_parse_result::INCOMPLETE) { + if (match_count >= p.min_count) { + // We have enough matches, succeed with what we have + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + // Not enough matches yet + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (result.status == utf8_parse_result::INVALID) { + // Malformed UTF-8 in input + if (match_count >= p.min_count) { + // We have enough matches, succeed up to here + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + // Not enough matches, fail + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + // Check if this codepoint matches our character class + bool matches = false; + for (const auto & range : p.ranges) { + if (range.contains(result.codepoint)) { + matches = true; + break; + } + } + + // If negated, invert the match result + if (p.negated) { + matches = !matches; + } + + if (matches) { + pos += result.bytes_consumed; + ++match_count; + } else { + // Character doesn't match, stop matching + break; + } + } + + // Check if we got enough matches + if (match_count < p.min_count) { + if (pos >= ctx.input.size() && ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) { + ++pos; // consume '\' + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos); + } + + switch (ctx.input[pos]) { + case '"': + case '\\': + case '/': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + ++pos; + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos); + case 'u': + return handle_unicode_escape(ctx, start, pos); + default: + // Invalid escape sequence + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + } + + static common_peg_parse_result handle_unicode_escape(common_peg_parse_context & ctx, size_t start, size_t & pos) { + ++pos; // consume 'u' + for (int i = 0; i < 4; ++i) { + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos); + } + if (!is_hex_digit(ctx.input[pos])) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + ++pos; + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos); + } + + common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) { + auto pos = start_pos; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '"') { + // Found closing quote - success (don't consume it) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (c == '\\') { + auto result = handle_escape_sequence(ctx, start_pos, pos); + if (!result.success()) { + return result; + } + } else { + auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + pos += utf8_result.bytes_consumed; + } + } + + // Reached end without finding closing quote + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_until_parser & p) const { + trie matcher(p.delimiters); + + // Scan input and check for delimiters + size_t pos = start_pos; + size_t last_valid_pos = start_pos; + + while (pos < ctx.input.size()) { + auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + // Incomplete UTF-8 sequence + if (!ctx.is_partial) { + // Input is complete but UTF-8 is incomplete = malformed + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + // Return what we have so far (before incomplete sequence) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + // Malformed UTF-8 + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + // Check if a delimiter starts at this position + auto match = matcher.check_at(ctx.input, pos); + + if (match == trie::COMPLETE_MATCH) { + // Found a complete delimiter, return everything before it + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (match == trie::PARTIAL_MATCH) { + // Found a partial match extending to end of input, return everything before it + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + pos += utf8_result.bytes_consumed; + last_valid_pos = pos; + } + + if (last_valid_pos == ctx.input.size() && ctx.is_partial) { + // Reached the end of a partial stream, there might still be more input that we need to consume. + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, last_valid_pos); + } + + common_peg_parse_result operator()(const common_peg_schema_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } + + common_peg_parse_result operator()(const common_peg_rule_parser & p) { + // Parse the child + auto result = arena.parse(p.child, ctx, start_pos); + + if (!result.fail()) { + std::string_view text; + if (result.start < ctx.input.size()) { + text = std::string_view(ctx.input).substr(result.start, result.end - result.start); + } + + auto node_id = ctx.ast.add_node( + p.name, + "", + result.start, + result.end, + text, + std::move(result.nodes), + result.need_more_input() + ); + + return common_peg_parse_result(result.type, result.start, result.end, { node_id }); + } + + return result; + } + + common_peg_parse_result operator()(const common_peg_tag_parser & p) { + // Parse the child + auto result = arena.parse(p.child, ctx, start_pos); + + if (!result.fail()) { + std::string_view text; + if (result.start < ctx.input.size()) { + text = std::string_view(ctx.input).substr(result.start, result.end - result.start); + } + + auto node_id = ctx.ast.add_node( + "", + p.tag, + result.start, + result.end, + text, + std::move(result.nodes), + result.need_more_input() + ); + + return common_peg_parse_result(result.type, result.start, result.end, { node_id }); + } + + return result; + } + + common_peg_parse_result operator()(const common_peg_ref_parser & p) { + auto rule_id = arena.get_rule(p.name); + return arena.parse(rule_id, ctx, start_pos); + } + + common_peg_parse_result operator()(const common_peg_atomic_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + if (result.need_more_input()) { + // Clear nodes so they don't propagate up. + result.nodes.clear(); + } + return result; + } +}; + +common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { + if (root_ == COMMON_PEG_INVALID_PARSER_ID) { + throw std::runtime_error("No root parser set"); + } + return parse(root_, ctx, start); +} + +common_peg_parse_result common_peg_arena::parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const { + // Execute parser + const auto & parser = parsers_.at(id); + parser_executor exec(*this, ctx, start); + return std::visit(exec, parser); +} + +common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) { + const auto & parser = parsers_.at(id); + if (auto ref = std::get_if(&parser)) { + return get_rule(ref->name); + } + return id; +} + +void common_peg_arena::resolve_refs() { + // Walk through all parsers and replace refs with their corresponding rule IDs + for (auto & parser : parsers_) { + std::visit([this](auto & p) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + for (auto & child : p.children) { + child = resolve_ref(child); + } + } else if constexpr (std::is_same_v) { + for (auto & child : p.children) { + child = resolve_ref(child); + } + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // These rules do not have children + } else { + static_assert(is_always_false_v); + } + }, parser); + } + + // Also flatten root if it's a ref + if (root_ != COMMON_PEG_INVALID_PARSER_ID) { + root_ = resolve_ref(root_); + } +} + +std::string common_peg_arena::dump(common_peg_parser_id id) const { + const auto & parser = parsers_.at(id); + + return std::visit([this](const auto & p) -> std::string { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + return "Epsilon"; + } else if constexpr (std::is_same_v) { + return "Start"; + } else if constexpr (std::is_same_v) { + return "End"; + } else if constexpr (std::is_same_v) { + return "Literal(" + p.literal + ")"; + } else if constexpr (std::is_same_v) { + std::vector parts; + for (const auto & child : p.children) { + parts.push_back(dump(child)); + } + return "Sequence(" + string_join(parts, ", ") + ")"; + } else if constexpr (std::is_same_v) { + std::vector parts; + for (const auto & child : p.children) { + parts.push_back(dump(child)); + } + return "Choice(" + string_join(parts, ", ") + ")"; + } else if constexpr (std::is_same_v) { + if (p.max_count == -1) { + return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)"; + } + return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + } else if constexpr (std::is_same_v) { + return "And(" + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Not(" + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Any"; + } else if constexpr (std::is_same_v) { + return "Space"; + } else if constexpr (std::is_same_v) { + if (p.max_count == -1) { + return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)"; + } + return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + } else if constexpr (std::is_same_v) { + return "JsonString()"; + } else if constexpr (std::is_same_v) { + return "Until(" + string_join(p.delimiters, " | ") + ")"; + } else if constexpr (std::is_same_v) { + return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; + } else if constexpr (std::is_same_v) { + return "Rule(" + p.name + ", " + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Ref(" + p.name + ")"; + } else { + return "Unknown"; + } + }, parser); +} + +common_peg_parser & common_peg_parser::operator=(const common_peg_parser & other) { + id_ = other.id_; + return *this; +} + +common_peg_parser & common_peg_parser::operator+=(const common_peg_parser & other) { + id_ = builder_.sequence({id_, other.id_}); + return *this; +} + +common_peg_parser & common_peg_parser::operator|=(const common_peg_parser & other) { + id_ = builder_.choice({id_, other.id_}); + return *this; +} + +common_peg_parser common_peg_parser::operator+(const common_peg_parser & other) const { + return builder_.sequence({id_, other.id_}); +} + +common_peg_parser common_peg_parser::operator|(const common_peg_parser & other) const { + return builder_.choice({id_, other.id_}); +} + +common_peg_parser common_peg_parser::operator<<(const common_peg_parser & other) const { + return builder_.sequence({id_, builder_.space(), other.id_}); +} + +common_peg_parser common_peg_parser::operator+(const char * str) const { + return *this + builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator+(const std::string & str) const { + return *this + builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator<<(const char * str) const { + return *this << builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator<<(const std::string & str) const { + return *this << builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator|(const char * str) const { + return *this | builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator|(const std::string & str) const { + return *this | builder_.literal(str); +} + +common_peg_parser operator+(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) + p; +} + +common_peg_parser operator+(const std::string & str, const common_peg_parser & p) { + return operator+(str.c_str(), p); +} + +common_peg_parser operator<<(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) << p; +} + +common_peg_parser operator<<(const std::string & str, const common_peg_parser & p) { + return operator<<(str.c_str(), p); +} + +common_peg_parser operator|(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) | p; +} + +common_peg_parser operator|(const std::string & str, const common_peg_parser & p) { + return operator|(str.c_str(), p); +} + +static std::string rule_name(const std::string & name) { + static const std::regex invalid_rule_chars_re("[^a-zA-Z0-9-]+"); + return std::regex_replace(name, invalid_rule_chars_re, "-"); +} + +common_peg_parser_builder::common_peg_parser_builder() {} + +common_peg_parser common_peg_parser_builder::sequence(const std::vector & parsers) { + // Flatten nested sequences + std::vector flattened; + for (const auto & p : parsers) { + const auto & parser = arena_.get(p); + if (auto seq = std::get_if(&parser)) { + flattened.insert(flattened.end(), seq->children.begin(), seq->children.end()); + } else { + flattened.push_back(p); + } + } + return wrap(arena_.add_parser(common_peg_sequence_parser{flattened})); +} + +common_peg_parser common_peg_parser_builder::sequence(const std::vector & parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return sequence(ids); +} + +common_peg_parser common_peg_parser_builder::sequence(std::initializer_list parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return sequence(ids); +} + +common_peg_parser common_peg_parser_builder::choice(const std::vector & parsers) { + // Flatten nested choices + std::vector flattened; + for (const auto & p : parsers) { + const auto & parser = arena_.get(p); + if (auto choice = std::get_if(&parser)) { + flattened.insert(flattened.end(), choice->children.begin(), choice->children.end()); + } else { + flattened.push_back(p); + } + } + return wrap(arena_.add_parser(common_peg_choice_parser{flattened})); +} + +common_peg_parser common_peg_parser_builder::choice(const std::vector & parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return choice(ids); +} + +common_peg_parser common_peg_parser_builder::choice(std::initializer_list parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return choice(ids); +} + +common_peg_parser common_peg_parser_builder::chars(const std::string & classes, int min, int max) { + auto [ranges, negated] = parse_char_classes(classes); + return wrap(arena_.add_parser(common_peg_chars_parser{classes, ranges, negated, min, max})); +} + +common_peg_parser common_peg_parser_builder::schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw) { + return wrap(arena_.add_parser(common_peg_schema_parser{p.id(), name, std::make_shared(schema), raw})); +} + +common_peg_parser common_peg_parser_builder::rule(const std::string & name, const common_peg_parser & p, bool trigger) { + auto clean_name = rule_name(name); + auto rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, p.id(), trigger}); + arena_.add_rule(clean_name, rule_id); + return ref(clean_name); +} + +common_peg_parser common_peg_parser_builder::rule(const std::string & name, const std::function & builder_fn, bool trigger) { + auto clean_name = rule_name(name); + if (arena_.has_rule(clean_name)) { + return ref(clean_name); + } + + // Create placeholder rule to allow recursive references + auto placeholder = any(); // Temporary placeholder + auto placeholder_rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, placeholder.id(), trigger}); + arena_.add_rule(clean_name, placeholder_rule_id); + + // Build the actual parser + auto parser = builder_fn(); + + // Replace placeholder with actual rule + auto rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, parser.id(), trigger}); + arena_.rules_[clean_name] = rule_id; + + return ref(clean_name); +} + +void common_peg_parser_builder::set_root(const common_peg_parser & p) { + arena_.set_root(p.id()); +} + +common_peg_arena common_peg_parser_builder::build() { + arena_.resolve_refs(); + return std::move(arena_); +} + +// JSON parsers +common_peg_parser common_peg_parser_builder::json_number() { + return rule("json-number", [this]() { + auto digit1_9 = chars("[1-9]", 1, 1); + auto digits = chars("[0-9]"); + auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})}); + auto frac = sequence({literal("."), digits}); + auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits}); + return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_string() { + return rule("json-string", [this]() { + return sequence({literal("\""), json_string_content(), literal("\""), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_bool() { + return rule("json-bool", [this]() { + return sequence({choice({literal("true"), literal("false")}), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_null() { + return rule("json-null", [this]() { + return sequence({literal("null"), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_object() { + return rule("json-object", [this]() { + auto ws = space(); + auto member = sequence({json_string(), ws, literal(":"), ws, json()}); + auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))}); + return sequence({ + literal("{"), + ws, + choice({ + literal("}"), + sequence({members, ws, literal("}")}) + }), + ws + }); + }); +} + +common_peg_parser common_peg_parser_builder::json_array() { + return rule("json-array", [this]() { + auto ws = space(); + auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); + return sequence({ + literal("["), + ws, + choice({ + literal("]"), + sequence({elements, ws, literal("]")}) + }), + ws + }); + }); +} + +common_peg_parser common_peg_parser_builder::json() { + return rule("json-value", [this]() { + return choice({ + json_object(), + json_array(), + json_string(), + json_number(), + json_bool(), + json_null() + }); + }); +} + +common_peg_parser common_peg_parser_builder::json_string_content() { + return wrap(arena_.add_parser(common_peg_json_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) { + auto ws = space(); + return sequence({ + literal("\"" + key + "\""), + ws, + literal(":"), + ws, + p, + }); +} + + +static std::string gbnf_escape_char_class(char c) { + switch (c) { + case '\n': return "\\n"; + case '\t': return "\\t"; + case '\r': return "\\r"; + case '\\': return "\\\\"; + case ']': return "\\]"; + case '[': return "\\["; + default: return std::string(1, c); + } +} + +static std::string gbnf_excluding_pattern(const std::vector & strings) { + trie matcher(strings); + auto pieces = matcher.collect_prefix_and_next(); + + std::string pattern; + for (size_t i = 0; i < pieces.size(); ++i) { + if (i > 0) { + pattern += " | "; + } + + const auto & pre = pieces[i].prefix; + const auto & chars = pieces[i].next_chars; + + std::string cls; + cls.reserve(chars.size()); + for (const auto & ch : chars) { + cls += gbnf_escape_char_class(ch); + } + + if (!pre.empty()) { + pattern += gbnf_format_literal(pre) + " [^" + cls + "]"; + } else { + pattern += "[^" + cls + "]"; + } + } + + return "(" + pattern + ")*"; +} + +static std::unordered_set collect_reachable_rules( + const common_peg_arena & arena, + const common_peg_parser_id & rule +) { + std::unordered_set reachable; + std::unordered_set visited; + + std::function visit = [&](common_peg_parser_id id) { + const auto & parser = arena.get(id); + + std::visit([&](const auto & p) { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // These parsers do not have any children + } else if constexpr (std::is_same_v) { + for (auto child : p.children) { + visit(child); + } + } else if constexpr (std::is_same_v) { + for (auto child : p.children) { + visit(child); + } + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + visit(p.child); + } else if constexpr (std::is_same_v) { + if (visited.find(p.name) == visited.end()) { + visited.insert(p.name); + reachable.insert(p.name); + visit(p.child); + } + } else if constexpr (std::is_same_v) { + // Traverse rules so we pick up everything + auto referenced_rule = arena.get_rule(p.name); + visit(referenced_rule); + } else { + static_assert(is_always_false_v); + } + }, parser); + }; + + visit(rule); + return reachable; +} + +// GBNF generation implementation +void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const { + // Generate GBNF for a parser + std::function to_gbnf = [&](common_peg_parser_id id) -> std::string { + const auto & parser = parsers_.at(id); + + return std::visit([&](const auto & p) -> std::string { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ""; + } else if constexpr (std::is_same_v) { + return gbnf_format_literal(p.literal); + } else if constexpr (std::is_same_v) { + std::string s; + for (const auto & child : p.children) { + if (!s.empty()) { + s += " "; + } + auto child_gbnf = to_gbnf(child); + const auto & child_parser = parsers_.at(child); + if (std::holds_alternative(child_parser) || + std::holds_alternative(child_parser)) { + s += "(" + child_gbnf + ")"; + } else { + s += child_gbnf; + } + } + return s; + } else if constexpr (std::is_same_v) { + std::string s; + for (const auto & child : p.children) { + if (!s.empty()) { + s += " | "; + } + auto child_gbnf = to_gbnf(child); + const auto & child_parser = parsers_.at(child); + if (std::holds_alternative(child_parser)) { + s += "(" + child_gbnf + ")"; + } else { + s += child_gbnf; + } + } + return s; + } else if constexpr (std::is_same_v) { + auto child_gbnf = to_gbnf(p.child); + const auto & child_parser = parsers_.at(p.child); + if (std::holds_alternative(child_parser) || + std::holds_alternative(child_parser)) { + child_gbnf = "(" + child_gbnf + ")"; + } + if (p.min_count == 0 && p.max_count == 1) { + return child_gbnf + "?"; + } + if (p.min_count == 0 && p.max_count == -1) { + return child_gbnf + "*"; + } + if (p.min_count == 1 && p.max_count == -1) { + return child_gbnf + "+"; + } + if (p.max_count == -1) { + return child_gbnf + "{" + std::to_string(p.min_count) + ",}"; + } + if (p.min_count == p.max_count) { + if (p.min_count == 1) { + return child_gbnf; + } + return child_gbnf + "{" + std::to_string(p.min_count) + "}"; + } + return child_gbnf + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; + } else if constexpr (std::is_same_v || std::is_same_v) { + return ""; // Lookahead not supported in GBNF + } else if constexpr (std::is_same_v) { + return "."; + } else if constexpr (std::is_same_v) { + return "space"; + } else if constexpr (std::is_same_v) { + std::string result = p.pattern; + if (p.min_count == 0 && p.max_count == 1) { + return result + "?"; + } + if (p.min_count == 0 && p.max_count == -1) { + return result + "*"; + } + if (p.min_count == 1 && p.max_count == -1) { + return result + "+"; + } + if (p.max_count == -1) { + return result + "{" + std::to_string(p.min_count) + ",}"; + } + if (p.min_count == p.max_count) { + if (p.min_count == 1) { + return result; + } + return result + "{" + std::to_string(p.min_count) + "}"; + } + return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; + } else if constexpr (std::is_same_v) { + return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; + } else if constexpr (std::is_same_v) { + if (p.delimiters.empty()) { + return ".*"; + } + return gbnf_excluding_pattern(p.delimiters); + } else if constexpr (std::is_same_v) { + if (p.schema) { + if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") { + // TODO: Implement more comprehensive grammar generation for raw strings. + // For now, use the grammar emitted from the underlying parser. + return to_gbnf(p.child); + } + return builder.add_schema(p.name, *p.schema); + } + return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return p.name; + } else if constexpr (std::is_same_v) { + // Refs should not exist after flattening, but kept just in case + return p.name; + } else if constexpr (std::is_same_v) { + return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return to_gbnf(p.child); + } else { + static_assert(is_always_false_v); + } + }, parser); + }; + + // Collect reachable rules + std::unordered_set reachable_rules; + + if (lazy) { + // Collect rules reachable from trigger rules + for (const auto & [name, id] : rules_) { + const auto & parser = parsers_.at(id); + if (auto rule = std::get_if(&parser)) { + if (rule->trigger) { + // Mark trigger as reachable and visit it + reachable_rules.insert(name); + auto add_rules = collect_reachable_rules(*this, id); + reachable_rules.insert(add_rules.begin(), add_rules.end()); + } + } + } + } else { + // Collect rules reachable from root + reachable_rules = collect_reachable_rules(*this, root_); + } + + // Create GBNF rules for all reachable rules + for (const auto & [name, rule_id] : rules_) { + if (reachable_rules.find(name) == reachable_rules.end()) { + continue; + } + + const auto & parser = parsers_.at(rule_id); + if (auto rule = std::get_if(&parser)) { + builder.add_rule(rule->name, to_gbnf(rule->child)); + } + } + + if (lazy) { + // Generate root rule from trigger rules only + std::vector trigger_names; + for (const auto & [name, rule_id] : rules_) { + const auto & parser = parsers_.at(rule_id); + if (auto rule = std::get_if(&parser)) { + if (rule->trigger) { + trigger_names.push_back(rule->name); + } + } + } + + // Sort for predictable order + std::sort(trigger_names.begin(), trigger_names.end()); + builder.add_rule("root", string_join(trigger_names, " | ")); + } else if (root_ != COMMON_PEG_INVALID_PARSER_ID) { + builder.add_rule("root", to_gbnf(root_)); + } +} + +static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & variant) { + using json = nlohmann::json; + + return std::visit([](const auto & p) -> json { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + return json{{"type", "epsilon"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "start"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "end"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "literal"}, {"literal", p.literal}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "sequence"}, {"children", p.children}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "choice"}, {"children", p.children}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "repetition"}, + {"child", p.child}, + {"min_count", p.min_count}, + {"max_count", p.max_count} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "and"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "not"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "any"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "space"}}; + } else if constexpr (std::is_same_v) { + json ranges = json::array(); + for (const auto & range : p.ranges) { + ranges.push_back({{"start", range.start}, {"end", range.end}}); + } + return json{ + {"type", "chars"}, + {"pattern", p.pattern}, + {"ranges", ranges}, + {"negated", p.negated}, + {"min_count", p.min_count}, + {"max_count", p.max_count} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "json_string"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "until"}, {"delimiters", p.delimiters}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "schema"}, + {"child", p.child}, + {"name", p.name}, + {"schema", p.schema ? *p.schema : nullptr}, + {"raw", p.raw} + }; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "rule"}, + {"name", p.name}, + {"child", p.child}, + {"trigger", p.trigger} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "ref"}, {"name", p.name}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "atomic"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "tag"}, + {"child", p.child}, + {"tag", p.tag} + }; + } + }, variant); +} + +nlohmann::json common_peg_arena::to_json() const { + auto parsers = nlohmann::json::array(); + for (const auto & parser : parsers_) { + parsers.push_back(serialize_parser_variant(parser)); + } + return nlohmann::json{ + {"parsers", parsers}, + {"rules", rules_}, + {"root", root_} + }; +} + +static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json & j) { + if (!j.contains("type") || !j["type"].is_string()) { + throw std::runtime_error("Parser variant JSON missing or invalid 'type' field"); + } + + std::string type = j["type"]; + + if (type == "epsilon") { + return common_peg_epsilon_parser{}; + } + if (type == "start") { + return common_peg_start_parser{}; + } + if (type == "end") { + return common_peg_end_parser{}; + } + if (type == "literal") { + if (!j.contains("literal") || !j["literal"].is_string()) { + throw std::runtime_error("literal parser missing or invalid 'literal' field"); + } + return common_peg_literal_parser{j["literal"]}; + } + if (type == "sequence") { + if (!j.contains("children") || !j["children"].is_array()) { + throw std::runtime_error("sequence parser missing or invalid 'children' field"); + } + return common_peg_sequence_parser{j["children"].get>()}; + } + if (type == "choice") { + if (!j.contains("children") || !j["children"].is_array()) { + throw std::runtime_error("choice parser missing or invalid 'children' field"); + } + return common_peg_choice_parser{j["children"].get>()}; + } + if (type == "repetition") { + if (!j.contains("child") || !j.contains("min_count") || !j.contains("max_count")) { + throw std::runtime_error("repetition parser missing required fields"); + } + return common_peg_repetition_parser{ + j["child"].get(), + j["min_count"].get(), + j["max_count"].get() + }; + } + if (type == "and") { + if (!j.contains("child")) { + throw std::runtime_error("and parser missing 'child' field"); + } + return common_peg_and_parser{j["child"].get()}; + } + if (type == "not") { + if (!j.contains("child")) { + throw std::runtime_error("not parser missing 'child' field"); + } + return common_peg_not_parser{j["child"].get()}; + } + if (type == "any") { + return common_peg_any_parser{}; + } + if (type == "space") { + return common_peg_space_parser{}; + } + if (type == "chars") { + if (!j.contains("pattern") || !j.contains("ranges") || !j.contains("negated") || + !j.contains("min_count") || !j.contains("max_count")) { + throw std::runtime_error("chars parser missing required fields"); + } + common_peg_chars_parser parser; + parser.pattern = j["pattern"]; + parser.negated = j["negated"]; + parser.min_count = j["min_count"]; + parser.max_count = j["max_count"]; + for (const auto & range_json : j["ranges"]) { + if (!range_json.contains("start") || !range_json.contains("end")) { + throw std::runtime_error("char_range missing 'start' or 'end' field"); + } + parser.ranges.push_back({ + range_json["start"].get(), + range_json["end"].get() + }); + } + return parser; + } + if (type == "json_string") { + return common_peg_json_string_parser{}; + } + if (type == "until") { + if (!j.contains("delimiters") || !j["delimiters"].is_array()) { + throw std::runtime_error("until parser missing or invalid 'delimiters' field"); + } + return common_peg_until_parser{j["delimiters"].get>()}; + } + if (type == "schema") { + if (!j.contains("child") || !j.contains("name") || !j.contains("schema") || !j.contains("raw")) { + throw std::runtime_error("schema parser missing required fields"); + } + common_peg_schema_parser parser; + parser.child = j["child"].get(); + parser.name = j["name"]; + if (!j["schema"].is_null()) { + parser.schema = std::make_shared(j["schema"]); + } + parser.raw = j["raw"].get(); + return parser; + } + if (type == "rule") { + if (!j.contains("name") || !j.contains("child") || !j.contains("trigger")) { + throw std::runtime_error("rule parser missing required fields"); + } + return common_peg_rule_parser{ + j["name"].get(), + j["child"].get(), + j["trigger"].get() + }; + } + if (type == "ref") { + if (!j.contains("name") || !j["name"].is_string()) { + throw std::runtime_error("ref parser missing or invalid 'name' field"); + } + return common_peg_ref_parser{j["name"]}; + } + if (type == "atomic") { + if (!j.contains("child")) { + throw std::runtime_error("tag parser missing required fields"); + } + return common_peg_atomic_parser{ + j["child"].get(), + }; + } + if (type == "tag") { + if (!j.contains("child") || !j.contains("tag")) { + throw std::runtime_error("tag parser missing required fields"); + } + return common_peg_tag_parser{ + j["child"].get(), + j["tag"].get(), + }; + } + + throw std::runtime_error("Unknown parser type: " + type); +} + +common_peg_arena common_peg_arena::from_json(const nlohmann::json & j) { + if (!j.contains("parsers") || !j["parsers"].is_array()) { + throw std::runtime_error("JSON missing or invalid 'parsers' array"); + } + if (!j.contains("rules") || !j["rules"].is_object()) { + throw std::runtime_error("JSON missing or invalid 'rules' object"); + } + if (!j.contains("root")) { + throw std::runtime_error("JSON missing 'root' field"); + } + + common_peg_arena arena; + + const auto & parsers_json = j["parsers"]; + arena.parsers_.reserve(parsers_json.size()); + for (const auto & parser_json : parsers_json) { + arena.parsers_.push_back(deserialize_parser_variant(parser_json)); + } + + arena.rules_ = j["rules"].get>(); + + for (const auto & [name, id] : arena.rules_) { + if (id >= arena.parsers_.size()) { + throw std::runtime_error("Rule '" + name + "' references invalid parser ID: " + std::to_string(id)); + } + } + + arena.root_ = j["root"].get(); + if (arena.root_ != COMMON_PEG_INVALID_PARSER_ID && arena.root_ >= arena.parsers_.size()) { + throw std::runtime_error("Root references invalid parser ID: " + std::to_string(arena.root_)); + } + + return arena; +} + +std::string common_peg_arena::save() const { + return to_json().dump(); +} + +void common_peg_arena::load(const std::string & data) { + *this = from_json(nlohmann::json::parse(data)); +} + +common_peg_arena build_peg_parser(const std::function & fn) { + common_peg_parser_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} diff --git a/common/peg-parser.h b/common/peg-parser.h new file mode 100644 index 00000000000..1cd640365f2 --- /dev/null +++ b/common/peg-parser.h @@ -0,0 +1,459 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +struct common_grammar_builder; + +class common_peg_parser_builder; + +using common_peg_parser_id = size_t; +constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast(-1); + +using common_peg_ast_id = size_t; +constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast(-1); + +// Lightweight wrapper around common_peg_parser_id for convenience +class common_peg_parser { + common_peg_parser_id id_; + common_peg_parser_builder & builder_; + + public: + common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {} + common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {} + + common_peg_parser & operator=(const common_peg_parser & other); + common_peg_parser & operator+=(const common_peg_parser & other); + common_peg_parser & operator|=(const common_peg_parser & other); + + operator common_peg_parser_id() const { return id_; } + common_peg_parser_id id() const { return id_; } + + common_peg_parser_builder & builder() const { return builder_; } + + // Creates a sequence + common_peg_parser operator+(const common_peg_parser & other) const; + + // Creates a sequence separated by spaces. + common_peg_parser operator<<(const common_peg_parser & other) const; + + // Creates a choice + common_peg_parser operator|(const common_peg_parser & other) const; + + common_peg_parser operator+(const char * str) const; + common_peg_parser operator+(const std::string & str) const; + common_peg_parser operator<<(const char * str) const; + common_peg_parser operator<<(const std::string & str) const; + common_peg_parser operator|(const char * str) const; + common_peg_parser operator|(const std::string & str) const; +}; + +common_peg_parser operator+(const char * str, const common_peg_parser & p); +common_peg_parser operator+(const std::string & str, const common_peg_parser & p); +common_peg_parser operator<<(const char * str, const common_peg_parser & p); +common_peg_parser operator<<(const std::string & str, const common_peg_parser & p); +common_peg_parser operator|(const char * str, const common_peg_parser & p); +common_peg_parser operator|(const std::string & str, const common_peg_parser & p); + +enum common_peg_parse_result_type { + COMMON_PEG_PARSE_RESULT_FAIL = 0, + COMMON_PEG_PARSE_RESULT_SUCCESS = 1, + COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2, +}; + +const char * common_peg_parse_result_type_name(common_peg_parse_result_type type); + +struct common_peg_ast_node { + common_peg_ast_id id; + std::string rule; + std::string tag; + size_t start; + size_t end; + std::string_view text; + std::vector children; + + bool is_partial = false; +}; + +struct common_peg_parse_result; + +using common_peg_ast_visitor = std::function; + +class common_peg_ast_arena { + std::vector nodes_; + public: + common_peg_ast_id add_node( + const std::string & rule, + const std::string & tag, + size_t start, + size_t end, + std::string_view text, + std::vector children, + bool is_partial = false + ) { + common_peg_ast_id id = nodes_.size(); + nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial}); + return id; + } + + const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); } + + size_t size() const { return nodes_.size(); } + + void clear() { nodes_.clear(); } + + void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const; + void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const; +}; + +struct common_peg_parse_result { + common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL; + size_t start = 0; + size_t end = 0; + + std::vector nodes; + + common_peg_parse_result() = default; + + common_peg_parse_result(common_peg_parse_result_type type, size_t start) + : type(type), start(start), end(start) {} + + common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end) + : type(type), start(start), end(end) {} + + common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector nodes) + : type(type), start(start), end(end), nodes(std::move(nodes)) {} + + bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; } + bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; } + bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; } +}; + +struct common_peg_parse_context { + std::string input; + bool is_partial; + common_peg_ast_arena ast; + + int parse_depth; + + common_peg_parse_context() + : is_partial(false), parse_depth(0) {} + + common_peg_parse_context(const std::string & input) + : input(input), is_partial(false), parse_depth(0) {} + + common_peg_parse_context(const std::string & input, bool is_partial) + : input(input), is_partial(is_partial), parse_depth(0) {} +}; + +class common_peg_arena; + +// Parser variants +struct common_peg_epsilon_parser {}; + +struct common_peg_start_parser {}; + +struct common_peg_end_parser {}; + +struct common_peg_literal_parser { + std::string literal; +}; + +struct common_peg_sequence_parser { + std::vector children; +}; + +struct common_peg_choice_parser { + std::vector children; +}; + +struct common_peg_repetition_parser { + common_peg_parser_id child; + int min_count; + int max_count; // -1 for unbounded +}; + +struct common_peg_and_parser { + common_peg_parser_id child; +}; + +struct common_peg_not_parser { + common_peg_parser_id child; +}; + +struct common_peg_any_parser {}; + +struct common_peg_space_parser {}; + +struct common_peg_chars_parser { + struct char_range { + uint32_t start; + uint32_t end; + bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; } + }; + + std::string pattern; + std::vector ranges; + bool negated; + int min_count; + int max_count; // -1 for unbounded +}; + +struct common_peg_json_string_parser {}; + +struct common_peg_until_parser { + std::vector delimiters; +}; + +struct common_peg_schema_parser { + common_peg_parser_id child; + std::string name; + std::shared_ptr schema; + + // Indicates if the GBNF should accept a raw string that matches the schema. + bool raw; +}; + +struct common_peg_rule_parser { + std::string name; + common_peg_parser_id child; + bool trigger; +}; + +struct common_peg_ref_parser { + std::string name; +}; + +struct common_peg_atomic_parser { + common_peg_parser_id child; +}; + +struct common_peg_tag_parser { + common_peg_parser_id child; + std::string tag; +}; + +// Variant holding all parser types +using common_peg_parser_variant = std::variant< + common_peg_epsilon_parser, + common_peg_start_parser, + common_peg_end_parser, + common_peg_literal_parser, + common_peg_sequence_parser, + common_peg_choice_parser, + common_peg_repetition_parser, + common_peg_and_parser, + common_peg_not_parser, + common_peg_any_parser, + common_peg_space_parser, + common_peg_chars_parser, + common_peg_json_string_parser, + common_peg_until_parser, + common_peg_schema_parser, + common_peg_rule_parser, + common_peg_ref_parser, + common_peg_atomic_parser, + common_peg_tag_parser +>; + +class common_peg_arena { + std::vector parsers_; + std::unordered_map rules_; + common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID; + + public: + const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); } + common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); } + + size_t size() const { return parsers_.size(); } + bool empty() const { return parsers_.empty(); } + + common_peg_parser_id get_rule(const std::string & name) const; + bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); } + + common_peg_parser_id root() const { return root_; } + void set_root(common_peg_parser_id id) { root_ = id; } + + common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const; + common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const; + + void resolve_refs(); + + void build_grammar(const common_grammar_builder & builder, bool lazy = false) const; + + std::string dump(common_peg_parser_id id) const; + + nlohmann::json to_json() const; + static common_peg_arena from_json(const nlohmann::json & j); + + std::string save() const; + void load(const std::string & data); + + friend class common_peg_parser_builder; + + private: + common_peg_parser_id add_parser(common_peg_parser_variant parser); + void add_rule(const std::string & name, common_peg_parser_id id); + + common_peg_parser_id resolve_ref(common_peg_parser_id id); +}; + +class common_peg_parser_builder { + common_peg_arena arena_; + + common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); } + common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); } + + public: + common_peg_parser_builder(); + + // Match nothing, always succeed. + // S -> ε + common_peg_parser eps() { return add(common_peg_epsilon_parser{}); } + + // Matches the start of the input. + // S -> ^ + common_peg_parser start() { return add(common_peg_start_parser{}); } + + // Matches the end of the input. + // S -> $ + common_peg_parser end() { return add(common_peg_end_parser{}); } + + // Matches an exact literal string. + // S -> "hello" + common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); } + + // Matches a sequence of parsers in order, all must succeed. + // S -> A B C + common_peg_parser sequence() { return add(common_peg_sequence_parser{}); } + common_peg_parser sequence(const std::vector & parsers); + common_peg_parser sequence(const std::vector & parsers); + common_peg_parser sequence(std::initializer_list parsers); + + // Matches the first parser that succeeds from a list of alternatives. + // S -> A | B | C + common_peg_parser choice() { return add(common_peg_choice_parser{}); } + common_peg_parser choice(const std::vector & parsers); + common_peg_parser choice(const std::vector & parsers); + common_peg_parser choice(std::initializer_list parsers); + + // Matches one or more repetitions of a parser. + // S -> A+ + common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); } + + // Matches zero or more repetitions of a parser, always succeeds. + // S -> A* + common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); } + + // Matches zero or one occurrence of a parser, always succeeds. + // S -> A? + common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); } + + // Positive lookahead: succeeds if child parser succeeds, consumes no input. + // S -> &A + common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); } + + // Negative lookahead: succeeds if child parser fails, consumes no input. + // S -> !A + common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); } + + // Matches any single character. + // S -> . + common_peg_parser any() { return add(common_peg_any_parser{}); } + + // Matches between min and max repetitions of characters from a character class. + // S -> [a-z]{m,n} + // + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + common_peg_parser chars(const std::string & classes, int min = 1, int max = -1); + + // Creates a lightweight reference to a named rule (resolved during build()). + // Use this for forward references in recursive grammars. + // expr_ref -> expr + common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); } + + // Matches zero or more whitespace characters (space, tab, newline). + // S -> [ \t\n]* + common_peg_parser space() { return add(common_peg_space_parser{}); } + + // Matches all characters until a delimiter is found (delimiter not consumed). + // S -> (!delim .)* + common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); } + + // Matches all characters until one of the delimiters in the list is found (delimiter not consumed). + // S -> (!delim .)* + common_peg_parser until_one_of(const std::vector & delimiters) { return add(common_peg_until_parser{delimiters}); } + + // Matches everything + // S -> .* + common_peg_parser rest() { return until_one_of({}); } + + // Matches between min and max repetitions of a parser (inclusive). + // S -> A{m,n} + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); } + + // Matches exactly n repetitions of a parser. + // S -> A{n} + common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); } + + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. + // value -> object | array | string | number | true | false | null + common_peg_parser json(); + common_peg_parser json_object(); + common_peg_parser json_string(); + common_peg_parser json_array(); + common_peg_parser json_number(); + common_peg_parser json_bool(); + common_peg_parser json_null(); + + // Matches JSON string content without the surrounding quotes. + // Useful for extracting content within a JSON string. + common_peg_parser json_string_content(); + + // Matches a JSON object member with a key and associated parser as the + // value. + common_peg_parser json_member(const std::string & key, const common_peg_parser & p); + + // Wraps a parser with JSON schema metadata for grammar generation. + // Used internally to convert JSON schemas to GBNF grammar rules. + common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false); + + // Creates a named rule, stores it in the grammar, and returns a ref. + // If trigger=true, marks this rule as an entry point for lazy grammar generation. + // auto json = p.rule("json", json_obj | json_arr | ...) + common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false); + + // Creates a named rule using a builder function, and returns a ref. + // If trigger=true, marks this rule as an entry point for lazy grammar generation. + // auto json = p.rule("json", [&]() { return json_object() | json_array() | ... }) + common_peg_parser rule(const std::string & name, const std::function & builder, bool trigger = false); + + // Creates a trigger rule. When generating a lazy grammar from the parser, + // only trigger rules and descendents are emitted. + common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); } + common_peg_parser trigger_rule(const std::string & name, const std::function & builder) { return rule(name, builder, true); } + + // Creates an atomic parser. Atomic parsers do not create an AST node if + // the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is + // intended for situations where partial output is undesirable. + common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); } + + // Tags create nodes in the generated AST for semantic purposes. + // Unlike rules, you can tag multiple nodes with the same tag. + common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); } + + void set_root(const common_peg_parser & p); + + common_peg_arena build(); +}; + +// Helper function for building parsers +common_peg_arena build_peg_parser(const std::function & fn); diff --git a/common/unicode.cpp b/common/unicode.cpp new file mode 100644 index 00000000000..56ab0f468e0 --- /dev/null +++ b/common/unicode.cpp @@ -0,0 +1,64 @@ +#include "unicode.h" + +// implementation adopted from src/unicode.cpp + +size_t utf8_sequence_length(unsigned char first_byte) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(first_byte) >> 4; + return lookup[highbits]; +} + +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { + if (offset >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + + // ASCII fast path + if (!(input[offset] & 0x80)) { + return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1); + } + + // Invalid: continuation byte as first byte + if (!(input[offset] & 0x40)) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + + // 2-byte sequence + if (!(input[offset] & 0x20)) { + if (offset + 1 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2); + } + + // 3-byte sequence + if (!(input[offset] & 0x10)) { + if (offset + 2 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3); + } + + // 4-byte sequence + if (!(input[offset] & 0x08)) { + if (offset + 3 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4); + } + + // Invalid first byte + return utf8_parse_result(utf8_parse_result::INVALID); +} diff --git a/common/unicode.h b/common/unicode.h new file mode 100644 index 00000000000..9d9e8e1227a --- /dev/null +++ b/common/unicode.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +// UTF-8 parsing utilities for streaming-aware unicode support + +struct utf8_parse_result { + uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS) + size_t bytes_consumed; // How many bytes this codepoint uses (1-4) + enum status { SUCCESS, INCOMPLETE, INVALID } status; + + utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0) + : codepoint(cp), bytes_consumed(bytes), status(s) {} +}; + +// Determine the expected length of a UTF-8 sequence from its first byte +// Returns 0 for invalid first bytes +size_t utf8_sequence_length(unsigned char first_byte); + +// Parse a single UTF-8 codepoint from input +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8ddb6d04cd9..454d1f9badd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1603,7 +1603,10 @@ def __init__(self, *args, **kwargs): self.preprocessor_config = {**self.preprocessor_config, **cfg} def get_vision_config(self) -> dict[str, Any] | None: + config_name = "vision_config" if not self.is_mistral_format else "vision_encoder" + if self.hparams.get("architectures")[0] == "Phi3VForCausalLM": + config_name = "img_processor" return self.global_config.get(config_name) def get_audio_config(self) -> dict[str, Any] | None: @@ -4509,7 +4512,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_add_bos_token(False) -@ModelBase.register("Phi3ForCausalLM") +@ModelBase.register("Phi3ForCausalLM", "Phi3VForCausalLM") class Phi3MiniModel(TextModel): model_arch = gguf.MODEL_ARCH.PHI3 @@ -4644,6 +4647,20 @@ def set_gguf_parameters(self): sliding_window = 0 self.gguf_writer.add_sliding_window(sliding_window) + def modify_tensors( + self, + data_torch: Tensor, + name: str, + bid: int | None, + ) -> Iterable[tuple[str, Tensor]]: + + VISION_PREFIX = "model.vision_embed_tokens." + + if name.startswith(VISION_PREFIX): + return + + yield from super().modify_tensors(data_torch, name, bid) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) @@ -4684,7 +4701,6 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) - @ModelBase.register("PhiMoEForCausalLM") class PhiMoeModel(Phi3MiniModel): model_arch = gguf.MODEL_ARCH.PHIMOE @@ -10006,6 +10022,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("Phi3VForCausalLM") +class Phi3VisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PHI3V) + self.gguf_writer.add_clip_has_llava_projector(True) + + num_img_tokens = self.find_vparam(["num_img_tokens"], optional=True) or 144 + self.gguf_writer.add_uint32("clip.vision.num_img_tokens", num_img_tokens) + + use_hd_tf = self.find_vparam(["use_hd_transform"], optional=True) + self.gguf_writer.add_bool("clip.vision.use_hd_transform", use_hd_tf if use_hd_tf is not None else True) + + with_sep = self.find_vparam(["with_learnable_separator"], optional=True) + self.gguf_writer.add_bool("clip.vision.with_learnable_separator", with_sep if with_sep is not None else True) + + hd_order = self.find_vparam(["hd_transform_order"], optional=True) + self.gguf_writer.add_string("clip.vision.hd_transform_order", hd_order or "sub_glb") + + img_dim_out = self.find_vparam(["image_dim_out", "dim_out"], optional=True) or self.find_vparam(["hidden_size"]) + self.gguf_writer.add_uint32("clip.vision.image_dim_out", img_dim_out) + + num_crops = self.find_vparam(["num_crops"], optional=True) or self.find_vparam(["num_crops"]) + self.gguf_writer.add_uint32("clip.vision.num_crops", num_crops) + + num_layers = self.find_vparam(["num_hidden_layers"], optional=True) or self.find_vparam(["num_hidden_layers"]) + self.gguf_writer.add_uint32("clip.vision.block_count", num_layers - 1) # Dropping the last (24th) layer + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if not name.startswith("model.vision_embed_tokens."): + return [] + + # Phi-3 Vision uses a 24-layer SigLIP but usually drops the 24th layer (index 23). + if "img_processor.vision_model.encoder.layers" in name: + try: + parts = name.split('.') + layer_idx = int(parts[parts.index("layers") + 1]) + if layer_idx == self.block_count - 1: + return [] + except (ValueError, IndexError): + pass + + return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("CogVLMForCausalLM") class CogVLMVisionModel(MmprojModel): diff --git a/docs/build.md b/docs/build.md index 7d244ff013b..316269288df 100644 --- a/docs/build.md +++ b/docs/build.md @@ -431,11 +431,22 @@ docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/ren ### For Linux users: +#### Using the LunarG Vulkan SDK + First, follow the official LunarG instructions for the installation and setup of the Vulkan SDK in the [Getting Started with the Linux Tarball Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html) guide. > [!IMPORTANT] > After completing the first step, ensure that you have used the `source` command on the `setup_env.sh` file inside of the Vulkan SDK in your current terminal session. Otherwise, the build won't work. Additionally, if you close out of your terminal, you must perform this step again if you intend to perform a build. However, there are ways to make this persistent. Refer to the Vulkan SDK guide linked in the first step for more information about any of this. +#### Using system packages + +On Debian / Ubuntu, you can install the required dependencies using: +```sh +sudo apt-get install libvulkan-dev glslc +``` + +#### Common steps + Second, after verifying that you have followed all of the SDK installation/setup steps, use this command to make sure before proceeding: ```bash vulkaninfo diff --git a/docs/development/parsing.md b/docs/development/parsing.md new file mode 100644 index 00000000000..113ab2e2ee4 --- /dev/null +++ b/docs/development/parsing.md @@ -0,0 +1,288 @@ +# Parsing Model Output + +The `common` library contains a PEG parser implementation suitable for parsing +model output. + +Types with the prefix `common_peg_*` are intended for general use and may have +applications beyond parsing model output, such as parsing user-provided regex +patterns. + +Types with the prefix `common_chat_peg_*` are specialized helpers for model +output. + +The parser features: + +- Partial parsing of streaming input +- Built-in JSON parsers +- AST generation with semantics via "tagged" nodes + +## Example + +Below is a contrived example demonstrating how to use the PEG parser to parse +output from a model that emits arguments as JSON. + +```cpp +auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + // Build a choice of all available tools + auto tool_choice = p.choice(); + for (const auto & tool : tools) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\""); + auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema)); + + tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}"); + } + + // Define the tool call structure: [{tool}] + auto tool_call = p.trigger_rule("tool-call", + p.sequence({ + p.literal("["), + tool_choice, + p.literal("]") + }) + ); + + // Parser accepts content, optionally followed by a tool call + return p.sequence({ + p.content(p.until("")), + p.optional(tool_call), + p.end() + }); +}); +``` + +For a more complete example, see `test_example_native()` in +[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp). + +## Parsers/Combinators + +### Basic Matchers + +- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match) +- **`start()`** - Matches the start of input (anchor `^`) +- **`end()`** - Matches the end of input (anchor `$`) +- **`literal(string)`** - Matches an exact literal string +- **`any()`** - Matches any single character (`.`) + +### Combinators + +- **`sequence(...)`** - Matches parsers in order; all must succeed +- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice) +- **`one_or_more(p)`** - Matches one or more repetitions (`+`) +- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`) +- **`optional(p)`** - Matches zero or one occurrence (`?`) +- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded) +- **`repeat(p, n)`** - Matches exactly n repetitions + +### Lookahead + +- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`) +- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`) + +### Character Classes & Utilities + +- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class +- **`space()`** - Matches zero or more whitespace characters (space, tab, newline) +- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed) +- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found +- **`rest()`** - Matches everything remaining (`.*`) + +### JSON Parsers + +- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null) +- **`json_object()`** - JSON object parser +- **`json_array()`** - JSON array parser +- **`json_string()`** - JSON string parser +- **`json_number()`** - JSON number parser +- **`json_bool()`** - JSON boolean parser +- **`json_null()`** - JSON null parser +- **`json_string_content()`** - JSON string content without surrounding quotes +- **`json_member(key, p)`** - JSON object member with specific key and value parser + +### Grammar Building + +- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars) +- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference +- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation) +- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for grammar generation + +### AST Control + +- **`atomic(p)`** - Prevents AST node creation for partial parses +- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags) + +## GBNF Grammar Generation + +The PEG parser also acts as a convenient DSL for generating GBNF grammars, with +some exceptions. + +```cpp +data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(params.tools, [&](const json & fn) { + builder.resolve_refs(fn.at("parameters")); + }); + parser.build_grammar(builder, data.grammar_lazy); +}); +``` + +The notable exception is the `negate(p)` lookahead parser, which cannot be +defined as a CFG grammar and therefore does not produce a rule. Its usage +should be limited and preferably hidden behind a `schema()` parser. In many +cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice. + +Another limitation is that the PEG parser requires an unambiguous grammar. In +contrast, the `llama-grammar` implementation can support ambiguous grammars, +though they are difficult to parse. + +### Lazy Grammars + +During lazy grammar generation, only rules reachable from a `trigger_rule(p)` +are emitted in the grammar. All trigger rules are added as alternations in the +root rule. It is still necessary to define trigger patterns, as the parser has +no interaction with the grammar sampling. + +### JSON Schema + +The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar` +implementation to generate the grammar instead of the underlying parser. + +The `raw` option emits a grammar suitable for a raw string instead of a JSON +string. In other words, it won't be wrapped in quotes or require escaping +quotes. It should only be used when `type == "string"`. + +The downside is that it can potentially lead to ambiguous grammars. For +example, if a user provides the pattern `^.*$`, the following grammar may be +generated: + +``` +root ::= "" .* "" +``` + +This creates an ambiguous grammar that cannot be parsed by the PEG parser. To +help mitigate this, if `.*` is found in the pattern, the grammar from the +underlying parser will be emitted instead. + +## Common AST Shapes for Chat Parsing + +Most model output can be placed in one of the following categories: + +- Content only +- Tool calling with arguments emitted as a single JSON object +- Tool calling with arguments emitted as separate entities, either XML + (Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2) + +To provide broad coverage, +[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and +mappers that help create parsers and visitors/extractors for these types. They +require parsers to tag nodes to conform to an AST "shape". This normalization +makes it easy to extract information and generalize parsing. + +### Simple + +The `common_chat_peg_builder` builds a `simple` parser that supports +content-only models with optional reasoning. + +- **`reasoning(p)`** - Tag node for extracting `reasoning_content` +- **`content(p)`** - Tag node for extracting `content` + +```cpp +build_chat_peg_parser([&](common_chat_peg_parser & p) { + return p.sequence({ + p.optional("" + p.reasoning(p.until("")) + ""), + p.content(p.until("")), + p.end() + }); +}); +``` + +Use `common_chat_peg_mapper` to extract the content. Note that this is already +done for you in `common_chat_peg_parser` when +`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`. + +```cpp +auto result = parser.parse(ctx); + +common_chat_msg msg; +auto mapper = common_chat_peg_mapper(msg); +mapper.from_ast(ctx.ast, result); +``` + +### Native + +The `common_chat_peg_native_builder` builds a `native` parser suitable for +models that emit tool arguments as a direct JSON object. + +- **`reasoning(p)`** - Tag node for `reasoning_content` +- **`content(p)`** - Tag node for `content` +- **`tool(p)`** - Tag entirety of a single tool call +- **`tool_open(p)`** - Tag start of a tool call +- **`tool_close(p)`** - Tag end of a tool call +- **`tool_id(p)`** - Tag the tool call ID (optional) +- **`tool_name(p)`** - Tag the tool name +- **`tool_args(p)`** - Tag the tool arguments + +```cpp +build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) { + auto get_weather_tool = p.tool(p.sequence({ + p.tool_open(p.literal("{")), + p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""), + p.literal(","), + p.json_member("arguments", p.tool_args(p.json())), + p.tool_close(p.literal("}")) + })); + + return p.sequence({ + p.content(p.until("")), + p.literal(""), + get_weather_tool, + p.literal(""), + p.end() + }); +}); +``` + +### Constructed + +The `common_chat_peg_constructed_builder` builds a `constructed` parser +suitable for models that emit tool arguments as separate entities, such as XML +tags. + +- **`reasoning(p)`** - Tag node for `reasoning_content` +- **`content(p)`** - Tag node for `content` +- **`tool(p)`** - Tag entirety of a single tool call +- **`tool_open(p)`** - Tag start of a tool call +- **`tool_close(p)`** - Tag end of a tool call +- **`tool_name(p)`** - Tag the tool name +- **`tool_arg(p)`** - Tag a complete tool argument (name + value) +- **`tool_arg_open(p)`** - Tag start of a tool argument +- **`tool_arg_close(p)`** - Tag end of a tool argument +- **`tool_arg_name(p)`** - Tag the argument name +- **`tool_arg_string_value(p)`** - Tag string value for the argument +- **`tool_arg_json_value(p)`** - Tag JSON value for the argument + +```cpp +build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto location_arg = p.tool_arg( + p.tool_arg_open(""), + p.tool_arg_string_value(p.until("")), + p.tool_arg_close(p.literal("")) + ); + + auto get_weather_tool = p.tool(p.sequence({ + p.tool_open(""), + location_arg, + p.tool_close(p.literal("")) + })); + + return p.sequence({ + p.content(p.until("")), + p.literal(""), + get_weather_tool, + p.literal(""), + p.end() + }); +}); +``` diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 9b10df00dae..0ccd901921d 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") - -if (MINGW) - set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version") -endif() - # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") option(GGML_CPU "ggml: enable CPU backend" ON) @@ -226,7 +221,7 @@ option(GGML_WEBGPU "ggml: use WebGPU" option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) - +option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) @@ -408,62 +403,67 @@ if (MSVC) /wd4996 # Disable POSIX deprecation warnings /wd4702 # Unreachable code warnings ) - function(disable_msvc_warnings target_name) + set(MSVC_COMPILE_OPTIONS + "$<$:/utf-8>" + "$<$:/utf-8>" + ) + function(configure_msvc_target target_name) if(TARGET ${target_name}) target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS}) + target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS}) endif() endfunction() - disable_msvc_warnings(ggml-base) - disable_msvc_warnings(ggml) - disable_msvc_warnings(ggml-cpu) - disable_msvc_warnings(ggml-cpu-x64) - disable_msvc_warnings(ggml-cpu-sse42) - disable_msvc_warnings(ggml-cpu-sandybridge) - disable_msvc_warnings(ggml-cpu-haswell) - disable_msvc_warnings(ggml-cpu-skylakex) - disable_msvc_warnings(ggml-cpu-icelake) - disable_msvc_warnings(ggml-cpu-alderlake) + configure_msvc_target(ggml-base) + configure_msvc_target(ggml) + configure_msvc_target(ggml-cpu) + configure_msvc_target(ggml-cpu-x64) + configure_msvc_target(ggml-cpu-sse42) + configure_msvc_target(ggml-cpu-sandybridge) + configure_msvc_target(ggml-cpu-haswell) + configure_msvc_target(ggml-cpu-skylakex) + configure_msvc_target(ggml-cpu-icelake) + configure_msvc_target(ggml-cpu-alderlake) if (GGML_BUILD_EXAMPLES) - disable_msvc_warnings(common-ggml) - disable_msvc_warnings(common) + configure_msvc_target(common-ggml) + configure_msvc_target(common) - disable_msvc_warnings(mnist-common) - disable_msvc_warnings(mnist-eval) - disable_msvc_warnings(mnist-train) + configure_msvc_target(mnist-common) + configure_msvc_target(mnist-eval) + configure_msvc_target(mnist-train) - disable_msvc_warnings(gpt-2-ctx) - disable_msvc_warnings(gpt-2-alloc) - disable_msvc_warnings(gpt-2-backend) - disable_msvc_warnings(gpt-2-sched) - disable_msvc_warnings(gpt-2-quantize) - disable_msvc_warnings(gpt-2-batched) + configure_msvc_target(gpt-2-ctx) + configure_msvc_target(gpt-2-alloc) + configure_msvc_target(gpt-2-backend) + configure_msvc_target(gpt-2-sched) + configure_msvc_target(gpt-2-quantize) + configure_msvc_target(gpt-2-batched) - disable_msvc_warnings(gpt-j) - disable_msvc_warnings(gpt-j-quantize) + configure_msvc_target(gpt-j) + configure_msvc_target(gpt-j-quantize) - disable_msvc_warnings(magika) - disable_msvc_warnings(yolov3-tiny) - disable_msvc_warnings(sam) + configure_msvc_target(magika) + configure_msvc_target(yolov3-tiny) + configure_msvc_target(sam) - disable_msvc_warnings(simple-ctx) - disable_msvc_warnings(simple-backend) + configure_msvc_target(simple-ctx) + configure_msvc_target(simple-backend) endif() if (GGML_BUILD_TESTS) - disable_msvc_warnings(test-mul-mat) - disable_msvc_warnings(test-arange) - disable_msvc_warnings(test-backend-ops) - disable_msvc_warnings(test-cont) - disable_msvc_warnings(test-conv-transpose) - disable_msvc_warnings(test-conv-transpose-1d) - disable_msvc_warnings(test-conv1d) - disable_msvc_warnings(test-conv2d) - disable_msvc_warnings(test-conv2d-dw) - disable_msvc_warnings(test-customop) - disable_msvc_warnings(test-dup) - disable_msvc_warnings(test-opt) - disable_msvc_warnings(test-pool) + configure_msvc_target(test-mul-mat) + configure_msvc_target(test-arange) + configure_msvc_target(test-backend-ops) + configure_msvc_target(test-cont) + configure_msvc_target(test-conv-transpose) + configure_msvc_target(test-conv-transpose-1d) + configure_msvc_target(test-conv1d) + configure_msvc_target(test-conv2d) + configure_msvc_target(test-conv2d-dw) + configure_msvc_target(test-customop) + configure_msvc_target(test-dup) + configure_msvc_target(test-opt) + configure_msvc_target(test-pool) endif () endif() diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 48da68fe7e3..b0e10f57685 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -204,6 +204,10 @@ # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) #endif +#if defined(_WIN32) && !defined(_WIN32_WINNT) +# define _WIN32_WINNT 0x0A00 +#endif + #include #include #include @@ -2279,7 +2283,7 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 64 +#define GGML_KQ_MASK_PAD 1 // q: [n_embd_k, n_batch, n_head, ne3 ] // k: [n_embd_k, n_kv, n_head_kv, ne3 ] diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index d93664b8b58..98606e9cf18 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -127,10 +127,6 @@ if (NOT MSVC) endif() endif() -if (MINGW) - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # POSIX conformance # diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1d88c826bb1..08681f35e3f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1240,10 +1240,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); } - if (sched->n_copies > 1) { - ggml_set_input(tensor_copy); - ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor - } + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor tensor_id_copy(src_id, src_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cd1b5e5b944..544c1e2a501 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2564,6 +2564,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return true; case GGML_OP_OUT_PROD: { +#ifdef ASCEND_310P + // Ger is not supported on 310p device + return false; +#endif switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 3247af8bb03..8507557267a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -683,22 +683,14 @@ bool ggml_is_numa(void) { } #if defined(__ARM_ARCH) - -#if defined(__linux__) && defined(__aarch64__) -#include -#endif - -static void ggml_init_arm_arch_features(void) { #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) -#if defined(__linux__) - ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#include +static void ggml_init_arm_arch_features(void) { + ggml_arm_arch_features.sve_cnt = svcntb(); +} #else - // TODO: add support of SVE for non-linux systems -#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." +static void ggml_init_arm_arch_features(void) {} #endif -#endif -} - #endif // __ARM_ARCH struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { @@ -2706,6 +2698,11 @@ struct ggml_cplan ggml_graph_plan( n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; } +#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) + // Emscripten without pthreads support can only use a single thread + n_threads = 1; +#endif + size_t work_size = 0; struct ggml_cplan cplan; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 608e82af69f..ac16b3681b7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16( const int64_t iih = ioh*s1 + ikh*d1 - p1; const int64_t iid = iod*s2 + ikd*d2 - p2; - if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; } else { const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 5cdd4bb2114..02443b8c638 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, + const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -790,8 +791,6 @@ void launch_fattn( GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); @@ -915,7 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -970,6 +969,9 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO other tensor dimensions after removal of WMMA kernel: + const uint3 ne01 = init_fastdiv_values(Q->ne[1]); + GGML_ASSERT(block_dim.x % warp_size == 0); fattn_kernel<<>>( (const char *) Q->data, @@ -980,7 +982,7 @@ void launch_fattn( KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, @@ -995,7 +997,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 57defb0c629..b6250cf7949 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,284 +5,211 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, half2> tile_B_16; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 16, float> tile_C_KQ_16; -typedef tile<16, 4, half2> tile_C_VKQ; -typedef tile<16, 8, half2> tile_C_VKQ_16; - -// Config options for specific head sizes. +// Config options for the MMA kernel. // Should not affect results, only speed/register pressure/shared memory use. -// -// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. -// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). -// Q_in_reg: whether the Q values should be kept permanently in registers. -// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. -// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. -// nbatch_V2: number of V half2 values in direction of DV to load in parallel. -// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. - -template -struct fattn_mma_f16_config; - -template <> -struct fattn_mma_f16_config< 64, 64> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 32; - } +struct fattn_mma_config { + int nthreads; // Number of threads per CUDA block. + int occupancy; // Targeted occupancy for the MMA kernel. + int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. + int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel. + int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel. + int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel. + int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support. + bool Q_in_reg; // Whether the Q values should be kept permanently in registers. + + constexpr __host__ __device__ fattn_mma_config( + int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) : + nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine), + nstages_target(nstages_target), Q_in_reg(Q_in_reg) {} }; -template <> -struct fattn_mma_f16_config< 80, 80> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 40; - } -}; - -template <> -struct fattn_mma_f16_config< 96, 96> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } +#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \ + static_assert( (occupancy_) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \ + static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \ + static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \ + static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \ + return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \ + } \ + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 48; - } -}; +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); -template <> -struct fattn_mma_f16_config<112, 112> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 56; - } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } + // TODO tune specifically for Volta + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 56; +static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (ampere_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 56; + if (turing_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + GGML_ASSERT(volta_mma_available(cc)); + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 56; - } -}; +static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) { +#if defined(AMPERE_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +#elif defined(TURING_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(VOLTA_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +#else + GGML_UNUSED_VARS(DKQ, DV, ncols); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +#endif // defined(AMPERE_MMA_AVAILABLE) +} -template <> -struct fattn_mma_f16_config<128, 128> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; +static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads; +} - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads; +} - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 64; - } +static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy; +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy; +} - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 64; - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa; +} - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa; +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 64; - } -}; +static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2; +} -template <> -struct fattn_mma_f16_config<256, 256> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2; +} - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2; +} - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2; +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine; +} - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine; +} - static int get_nbatch_combine_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 128 : 64; - } - return 64; - } +static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target; +} - static constexpr __device__ int get_nbatch_combine_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 128 : 64; -#else - GGML_UNUSED(ncols); - return 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } -}; +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target; +} -template <> -struct fattn_mma_f16_config<576, 512> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 8; - static constexpr bool Q_in_reg = false; - static constexpr int nstages_target = 1; +static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg; +} - static int get_nbatch_K2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 96 : 160; - } - return ncols <= 16 ? 288 : 160; - } +static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; +} - static constexpr __device__ int get_nbatch_K2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 96 : 160; -#else - return ncols <= 16 ? 288 : 160; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +// ------------------------------------------------------------------------------------------------------------------ - static int get_nbatch_V2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 64 : 128; - } - return ncols <= 16 ? 256 : 128; - } +static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { + return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0; +} - static constexpr __device__ int get_nbatch_V2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 64 : 128; +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) { +#ifdef CP_ASYNC_AVAILABLE + return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0; #else - return ncols <= 16 ? 256 : 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 128; - } -}; + GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2); + return 0; +#endif // CP_ASYNC_AVAILABLE +} // ------------------------------------------------------------------------------------------------------------------ -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - - if (use_cp_async) { + if constexpr (use_cp_async) { + static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); const int chunks_per_row = D2 / h2_per_chunk; @@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } }; - ggml_cuda_unroll<5>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); @@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } } }; - ggml_cuda_unroll<3>{}(load); + // 1: max 32* 4=128 bytes, 64 half + // 2: max 16* 4= 64 bytes, 32 half + // 3: max 8* 4= 32 bytes, 16 half + // 4: max 4* 4= 16 bytes, 8 half + ggml_cuda_unroll<4>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( - const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - - if (use_cp_async) { + const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, + const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + if constexpr (use_cp_async) { + static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(!oob_check, "OOB check incompatible with cp_async"); constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; @@ -361,50 +299,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); } - return; - } + } else if constexpr (oob_check) { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; - constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + } } + } else if constexpr (nbatch_fa < 2*WARP_SIZE) { + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } + + const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); - tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + } + } else { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + } + } } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, const int stride_K, const int stride_V, @@ -412,27 +385,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, - half2 * const __restrict__ tile_mask, - const tile_B * const __restrict__ Q_B, - tile_C_VKQ * const __restrict__ VKQ_C, + half * const __restrict__ tile_mask, + T_B_KQ * const __restrict__ Q_B, + T_C_VKQ * const __restrict__ VKQ_C, float * const __restrict__ KQ_max, float * const __restrict__ KQ_rowsum, - const int kb0) { -#ifdef TURING_MMA_AVAILABLE - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int ncols = ncols1 * ncols2; - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + const int jt, + const int kb0, + const int k_VKQ_sup) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; @@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; - const int k_VKQ_0 = kb0 * c::nbatch_fa; - tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; - - // Use wide variants of tiles if ntiles >= 2. - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; - tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + const int k_VKQ_0 = kb0 * nbatch_fa; +#if defined(TURING_MMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#else // Volta + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { + static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); static_assert(!mla, "multi-stage loading not implemented for MLA"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } } @@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); if (use_cp_async) { cp_async_wait_all(); } @@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Calculate tile of KQ: - if constexpr (c::Q_in_reg) { + if constexpr (Q_in_reg) { #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - tile_A K_A; + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); } } } } else { - static_assert(ntiles == 2, "ntiles != 2 not implemented"); + static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; - tile_A K_A; + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); } } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } if (use_logit_softcap) { - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J; + static_assert(nbatch_fa % stride == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < nbatch_fa/stride; ++i) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { + for (int l = 0; l < T_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } @@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (cols_per_warp == 8) { + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I; #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2; - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]); + } } } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); - - KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); + KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; + } else { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; + } } } - } else { // ntiles > 1 - if (ncols2 > 1 || mask_h2) { + } else { // not Turing mma or T_B_KQ::I > 8 + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { - const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; - const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { + const int i = (i0 + T_C_KQ::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); - const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; - KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; - KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; - } + const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]); + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); -#pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) { + // Turing + Volta: + KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]); } } } - // Values per KQ column are spread across 4 threads, does not need full warp reduce: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { +#if defined(TURING_MMA_AVAILABLE) + // Values per KQ column are spread across 4 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 1; +#else + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll - for (int offset = 2; offset >= 1; offset >>= 1) { + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } - static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - - KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); - - KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + // Turing + Volta: + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); + KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; + } else { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } } } @@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; - tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)]; + static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size"); + if constexpr (cols_per_warp == 8) { #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); - } + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_half2(KQ_C[k]); } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } } @@ -724,72 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + + // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; const int i0_diff = i0_stop - i0_start; - if (nstages <= 1 && i0_start < reusable_cutoff) { - constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { - cp_async_wait_all(); + if constexpr (nstages <= 1) { + if (i0_start < reusable_cutoff) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); } - __syncthreads(); } const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; - // Calculate VKQ tile: +#if defined(TURING_MMA_AVAILABLE) + constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; - tile_A A; + T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if constexpr (T_B_KQ::I == 8) { + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); - } + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); } } } +#else // Volta + constexpr int i0_stride = 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size"); + static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I; + + T_A_VKQ A; // Transposed in both SRAM and registers, load normally. + load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); + } + } +#endif // defined(TURING_MMA_AVAILABLE) - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template +#if defined(TURING_MMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<> struct mma_tile_sizes<8> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile< 8, 8, half2>; // column-major + using T_C_KQ = tile<16, 8, float>; // row-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile< 8, 8, half2>; // column-major + using T_C_VKQ = tile<16, 4, half2>; // row-major +}; +#else // Volta +template struct mma_tile_sizes { + using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major + using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#endif // defined(TURING_MMA_AVAILABLE) + +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, + const int ne11, const int stride_Q1, const int stride_Q2, const int stride_K, @@ -798,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#ifdef TURING_MMA_AVAILABLE +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int ncols = ncols1 * ncols2; - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + constexpr int ncols = ncols1 * ncols2; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; + + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); + constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); + + if (cols_per_warp > ncols) { + NO_DEVICE_CODE; + return; + } static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); @@ -826,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; - half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; - half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; - half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; - - tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; - tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; + half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K; + half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max); - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; +#if defined(TURING_MMA_AVAILABLE) + T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#else // Volta + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) float KQ_rowsum[cols_per_thread] = {0.0f}; float KQ_max[cols_per_thread]; @@ -868,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < ne01) { + if (jt*ncols1 + j < int(ne01.z)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -889,63 +928,96 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if (Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); - } - } + for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) { + load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } } __syncthreads(); + int kb0 = kb0_start; + // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } - // Iterate over ne11 == previous tokens: - int kb0 = kb0_start; for (; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); - } - { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + if constexpr (ncols2 == 1) { + if (ne11 % nbatch_fa == 0) { + constexpr bool last_iter = true; + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool last_iter = true; + constexpr bool oob_check = true; + const int k_VKQ_sup = ne11 - kb0*nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + } else { constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) { __syncthreads(); } // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; +#if defined(TURING_MMA_AVAILABLE) + // The partial sums are spread across 8/4 threads. + constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; + constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#else // Volta + // The partial sums are spread across 2 threads. + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -962,8 +1034,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); - const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -977,12 +1048,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -991,30 +1063,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const int col = (threadIdx.x / 2) % 2; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); - constexpr int tile_stride = nbatch_combine + 4; + constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); - if constexpr (ntiles == 1) { - const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset - const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + if constexpr (cols_per_warp == 8) { + const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1023,24 +1105,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < tile_B::I) { + if (needs_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } else { - static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); - const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta - + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) - + tile_C_VKQ_16::get_i(threadIdx.x % 4); - const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum - - if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + // jc_cwm = jc combine write meta + // KQ_cmr = KQ combine max rowsum + // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale. +#if defined(TURING_MMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); + const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#else // Volta + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); + const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); + const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8; +#endif // defined(TURING_MMA_AVAILABLE) + + if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) { ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1048,18 +1136,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (needs_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (is_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } - static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. @@ -1135,32 +1222,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data + if constexpr (cols_per_warp == 8) { + const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { + const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < T_B_KQ::ne; ++l) { + const int k = k1 + T_B_KQ::get_j(l); tile_Q[jc_cwd*tile_stride + k] = B.x[l]; } } } else { + const int j0 = threadIdx.y*cols_per_warp; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; } } } @@ -1195,7 +1279,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { continue; } @@ -1233,16 +1317,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template -__launch_bounds__(nwarps*WARP_SIZE, 1) +template +__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -1258,14 +1342,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1281,23 +1365,22 @@ static __global__ void flash_attn_ext_f16( static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - typedef fattn_mma_f16_config c; - - static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); + constexpr int nwarps = nthreads / WARP_SIZE; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); const int stride_V = mla ? stride_K : nb21 / sizeof(half2); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -1318,35 +1401,31 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } - constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { - constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -1366,29 +1445,26 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1400,7 +1476,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1409,36 +1485,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc); + const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc); + const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc); + const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc); + const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc); + const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); + const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - constexpr int ncols = ncols1 * ncols2; - constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int nwarps_max_x = ncols / cols_per_warp; - constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; - constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int nwarps = nthreads / WARP_SIZE; constexpr bool mla = DKQ == 576; - const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); - - static_assert(DKQ % tile_B::J == 0, "bad DKQ"); - static_assert(DV % tile_A::J == 0, "bad DV"); - static_assert(ncols % cols_per_warp == 0, "bad ncols"); - - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2); const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ? std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); @@ -1448,7 +1518,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1459,7 +1529,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1471,7 +1541,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); } diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 3e58d64ff9d..63b235674eb 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half * const __restrict__ mask, + const uint3 ne01, const float logit_softcap, const float slope, T_KQ * const KQ, @@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float * const KQ_sum, T_acc * const VKQ, const int k_VKQ_0, - const int k_VKQ_max) { + const int k_VKQ_max, + const int col_Q_0) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( // Apply logit softcap + mask, update KQ_max: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { - const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01); #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { @@ -736,7 +738,7 @@ static __global__ void flash_attn_tile( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -781,11 +783,11 @@ static __global__ void flash_attn_tile( const int sequence = blockIdx.z / (ne02/ncols2); const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; const int stride_K2 = nb11 / sizeof(half2); const int stride_V2 = nb21 / sizeof(half2); @@ -842,11 +844,9 @@ static __global__ void flash_attn_tile( for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { float tmp_f[cpy_ne_D] = {0.0f}; - if (ncols1 == 1 || col_Q_0 + j < ne01) { - ggml_cuda_memcpy_1 - (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) - + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); - } + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -881,23 +881,23 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { // Branch without out-of-bounds checks. for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (ncols1 > 1 && col_Q_0 + j >= ne01) { + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) { return; } const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; - const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 67aa67ecb94..0bae9849a96 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); // Set memory to zero if out of bounds: - if (ncols > 1 && ic0 + j >= ne01) { + if (ncols > 1 && ic0 + j >= int(ne01.z)) { #pragma unroll for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec( const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); } @@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); } @@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec( sum = logit_softcap*tanhf(sum); } - if (mask) { + if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { sum += slope*__half2float(maskh[j*ne11 + i_KQ]); } @@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 1 && ic0 + j_VKQ >= ne01) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { break; } @@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec( if (gridDim.y == 1) { dst_val /= KQ_sum[j_VKQ]; } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; } } @@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec( } - if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { + dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6c90d6d52b3..0d81f0aae0a 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + warp_size > D && i >= D) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? + __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { + if (ic0 + j_VKQ >= int(ne01.z)) { return; } @@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { @@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 7235f1b77ae..cd3bfd4051a 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -2,9 +2,9 @@ #include "common.cuh" -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#if defined(GGML_USE_MUSA) #define GGML_USE_WMMA_FATTN -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_MUSA) #if defined(GGML_HIP_ROCWMMA_FATTN) #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 82405991cea..dec01ff8ad2 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { - if (Q->ne[1] <= 8/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (Q->ne[1] <= 16/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const bool use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; @@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; - // If Turing tensor cores available, use them: - if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) { + // If Turing tensor cores are available, use them: + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { @@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_VEC; } } + return BEST_FATTN_KERNEL_MMA_F16; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + if (Q->ne[1] * gqa_ratio_eff <= 16) { + return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. + } return BEST_FATTN_KERNEL_MMA_F16; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 0ed42e87d3d..6ea7a809a47 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { namespace ggml_cuda_mma { + // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel, + // effectively the warp is being split into subgroups of threads that each perform a single mma instruction. + // In those cases the data can be split in different ways across the warp. + enum data_layout { + // By default the data uses the I direction as its major dimension and the J direction as its minor dimension. + // For the A/C matrices this means I major == row major, J major == column major. + // For the B matrix this means I major == column major, J major == row major. + // MIRRORED == Each data value is held exactly once per thread subgroup. + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. + DATA_LAYOUT_I_MAJOR_MIRRORED = 10, + DATA_LAYOUT_J_MAJOR_MIRRORED = 20, + }; + // Implemented mma combinations are: + // - (I_MAJOR, I_MAJOR) -> I_MAJOR + // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR + // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR + + template + struct tile {}; + template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; @@ -131,9 +152,9 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 32 && J == 8) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2); #else - return (l & 2) | (threadIdx.x & ~2); + return (l & 2) + (threadIdx.x & ~2); #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { NO_DEVICE_CODE; @@ -143,7 +164,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 32 && J == 8) { - return (threadIdx.x & 2) | (l & (4 + 1)); + return (threadIdx.x & 2) + (l & (4 + 1)); } else { NO_DEVICE_CODE; return -1; @@ -196,9 +217,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 8) | (threadIdx.x / 4); + return ((l / 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { - return (((l / 2) % 2) * 8) | (threadIdx.x / 4); + return (((l / 2) % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -211,11 +232,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 8) { - return ((threadIdx.x % 4) * 2) | (l % 2); + return ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 16 && J == 16) { - return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); + return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -227,26 +248,24 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; + static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 8 && J == 8) return true; - if (I == 32 && J == 8) return true; + if (I == 32 && J == 4) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 8 && J == 8) { - return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); - } else if constexpr (I == 32 && J == 8) { + if constexpr (I == 32 && J == 4) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); #else return threadIdx.x; #endif // GGML_CUDA_MMA_NO_VOLTA_PERM @@ -257,7 +276,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr ((I == 8 || I == 32) && J == 8) { + if constexpr (I == 32 && J == 4) { return l; } else { NO_DEVICE_CODE; @@ -307,11 +326,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { - return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); + return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -320,13 +339,13 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else if constexpr (I == 32 && J == 8) { - return ((l & 2) * 2) | (threadIdx.x % 4); + return ((l & 2) * 2) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -336,14 +355,15 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + static constexpr int ne = I * J / WARP_SIZE; -#if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; +#if defined(AMD_WMMA_AVAILABLE) static constexpr __device__ bool supported() { if (I == 16 && J == 8) return true; return false; @@ -367,9 +387,6 @@ namespace ggml_cuda_mma { } } #else - static constexpr int ne = I * J / WARP_SIZE; - nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; - static constexpr __device__ bool supported() { if (I == 8 && J == 8) return true; if (I == 16 && J == 4) return true; @@ -381,9 +398,9 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -392,11 +409,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -405,6 +422,73 @@ namespace ggml_cuda_mma { #endif // defined(AMD_WMMA_AVAILABLE) }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 2) + (l % 2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + +#if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -422,9 +506,26 @@ namespace ggml_cuda_mma { return ret; } +#else // Volta + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 4) { + ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]); + + // On Volta FP16 and FP32 tiles have a different memory layout, + // for the conversion threads with an offset of 2 need to exchange half their values: + ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync( + 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE); + } + return ret; + } +#endif // defined(TURING_MMA_AVAILABLE) - template - static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> #pragma unroll @@ -511,18 +612,6 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - GGML_UNUSED_VARS(t, xs0, stride); - NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#endif // TURING_MMA_AVAILABLE - } - - template - static __device__ __forceinline__ void load_ldmatrix( - tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #if 1 // TODO: more generic handling @@ -533,9 +622,31 @@ namespace ggml_cuda_mma { load_generic(t, xs0, stride); #endif // 1 #else - tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t; - load_ldmatrix(t16[0], xs0 + 0*stride, stride); - load_ldmatrix(t16[1], xs0 + 16*stride, stride); + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l0 = 0; l0 < t.ne; l0 += 2) { + ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0)); + } + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } @@ -860,14 +971,14 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void mma( tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile & B) { - tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D; - tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A; + tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D); + const tile<16, K, T2> * A16 = reinterpret_cast *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); } static __device__ __forceinline__ void mma( - tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) { + tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -880,20 +991,30 @@ namespace ggml_cuda_mma { "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile <16, 8, float> * D16 = reinterpret_cast *>(&D); - const tile<16, 8, half2> * A16 = reinterpret_cast *>(&A); - mma(D16[0], A16[0], B); - mma(D16[1], A16[1], B); -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + static __device__ __forceinline__ void mma( + tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a0a2e42fe..e1c695c5c0f 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -37,23 +37,19 @@ static __global__ void mul_mat_f( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } @@ -248,6 +244,9 @@ static __global__ void mul_mat_f( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 329500a03e0..c647baef878 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg } ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { - if (ppls->data.find(name) == ppls->data.end()) { + if (ppls->data.find(name) == ppls->data.end()) { return nullptr; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 62bc4ba45fc..4d2bfcf91c6 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -146,6 +146,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipelin id device; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines + + NSLock * lock; }; ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { @@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); - res->obj = library; - res->device = device; + res->obj = library; + res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev res->obj = library; res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { ggml_metal_pipelines_free(lib->pipelines); + [lib->lock release]; + free(lib); } ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { - return ggml_metal_pipelines_get(lib->pipelines, name); + [lib->lock lock]; + + ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); + + [lib->lock unlock]; + + return res; } ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { - // note: the pipelines are cached in the library per device, so they are shared across all metal contexts - ggml_critical_section_start(); + [lib->lock lock]; - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name); if (res) { - ggml_critical_section_end(); + [lib->lock unlock]; return res; } @@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; } if (!mtl_function) { - ggml_critical_section_end(); + [lib->lock unlock]; GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); if (error) { @@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l (int) res->obj.threadExecutionWidth); if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { - ggml_critical_section_end(); + [lib->lock unlock]; GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); @@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l ggml_metal_pipelines_add(lib->pipelines, name, res); } - ggml_critical_section_end(); + [lib->lock unlock]; return res; } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 95966ce1d8e..f917a745d5a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants { uint32_t orig_ncols; uint32_t ncols_input; uint32_t ncols_output; + uint32_t k; uint32_t nrows; uint32_t first_pass; uint32_t last_pass; @@ -1673,6 +1674,14 @@ class vk_perf_logger { timings[name.str()].push_back(time); return; } + if (node->op == GGML_OP_TOP_K) { + std::stringstream name; + name << ggml_op_name(node->op) << + " K=" << node->ne[0] << + " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons uint32_t nrows = ggml_nrows(src0); uint32_t k = dst->ne[0]; - vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 }; - // Reserve space for ivec2 per element, double buffered - const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); - const size_t x_sz = dbl_buf_size * 2; - uint32_t dbl_buf_index = 0; - - if (ctx->prealloc_size_x < x_sz) { - ctx->prealloc_size_x = x_sz; - ggml_vk_preallocate_buffers(ctx, subctx); - } if (ctx->prealloc_x_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } @@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // largest elements. Repeat until we have the top K elements. // Need to do at least one iteration to write out the results. bool done_one_iter = false; + uint32_t dbl_buf_index = 0; + size_t dbl_buf_size; while (num_elements > k || !done_one_iter) { - done_one_iter = true; // Prefer going as small as num_topk_pipelines - 3 for perf reasons. // But if K is larger, then we need a larger workgroup @@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // Number of elements remaining after this pass uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + pc2.ncols_output = num_dst_elements; + + if (!done_one_iter) { + // Reserve space for ivec2 per element, double buffered + // K per workgroup per row + dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int); + dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const size_t x_sz = dbl_buf_size * 2; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + } + vk_subbuffer src_buf; vk_subbuffer dst_buf; @@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons if (num_elements > k) { ggml_vk_sync_buffers(ctx, subctx); } + done_one_iter = true; } ctx->prealloc_x_need_sync = true; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp index cd858b7d326..49d4ab8e7c0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols_input; dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (col < s) { @@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) { } } - if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (col < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + col] = dst_row[col].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + col] = dst_row[col].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + col] = dst_row[col]; + if (gl_WorkGroupID.x * p.k + col < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + col] = dst_row[col]; + } } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp index c902e60237a..f794285ee15 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -60,7 +61,7 @@ void topk(const uint row) { const uint row_offset = row * p.ncols_input; dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -68,7 +69,7 @@ void topk(const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (tid < s) { @@ -98,7 +99,7 @@ void topk(const uint row) { uint range_max = 0xFF800000; // How many are above the current range, and how many we need to find. uint total = 0; - uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); while (mask != 0) { barrier(); @@ -139,7 +140,7 @@ void topk(const uint row) { range_max = range_min + ((min_idx + 1) << shift); range_min = range_min + (min_idx << shift); - if (total == p.ncols_output) { + if (total == p.k) { break; } total -= counts[min_idx]; @@ -179,13 +180,17 @@ void topk(const uint row) { barrier(); } - if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (tid < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + tid] = dst_row[tid].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + tid] = dst_row[tid].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + tid] = dst_row[tid]; + if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + tid] = dst_row[tid]; + } } } } diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index c6a95d51512..3ccce58aa39 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -39,8 +39,23 @@ add_dependencies(ggml-webgpu generate_shaders) if(EMSCRIPTEN) set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg") - target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + if(NOT EMDAWNWEBGPU_DIR) + # default built-in port + target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") + else() + # custom port + target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + endif() + + if (GGML_WEBGPU_JSPI) + target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions") + target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions") + else() + target_compile_options(ggml-webgpu PRIVATE "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions") + endif() else() find_package(Dawn REQUIRED) set(DawnWebGPU_TARGET dawn::webgpu_dawn) @@ -48,6 +63,9 @@ endif() if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) + if(EMSCRIPTEN) + target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2") + endif() endif() if (GGML_WEBGPU_CPU_PROFILE) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 9e8cbc477ed..a7476b109df 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -9,6 +9,10 @@ #include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" +#ifdef __EMSCRIPTEN__ +# include +#endif + #include #include @@ -261,9 +265,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; - uint32_t subgroup_size; wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. @@ -449,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, // inflight_max may be 0, meaning that we must wait on all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint inflight_threads = ctx->inflight_threads; - uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); while (futures.size() >= inflight_max && futures.size() > 0) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); @@ -986,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; +#ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. uint32_t wg_m_sg_tile = @@ -995,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; } else { +#endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; +#ifndef __EMSCRIPTEN__ } +#endif + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } @@ -1419,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint inflight_threads = ctx->inflight_threads; - uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); // Process events and check for completed submissions @@ -1758,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + std::string proc_mul_mat_f32_f32; + std::string proc_mul_mat_f32_f32_vec; + std::string proc_mul_mat_f16_f32; + std::string proc_mul_mat_f16_f32_vec; + std::string proc_mul_mat_f16_f16; + std::string proc_mul_mat_f16_f16_vec; + std::string proc_mul_mat_q4_0_f32; + std::string proc_mul_mat_q4_0_f32_vec; + + std::vector mul_mat_constants; +#ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); @@ -1770,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); - std::string proc_mul_mat_subgroup_matrix_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f32_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f16_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), - "mul_mat_subgroup_matrix_q4_0_f32_vec"); } else { - std::vector mul_mat_reg_tile_constants(3); - mul_mat_reg_tile_constants[0].key = "TILE_K"; - mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; - mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; - mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; +#endif + mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); std::map reg_repls; reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - // Process each reg-tile shader with tile replacements. - // Keep the processed strings in-scope so .c_str() remains valid. - std::string proc_mul_mat_reg_tile_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), - "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), - "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), - "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), - "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), - "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), - "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), - "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), - "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); +#ifndef __EMSCRIPTEN__ } +#endif + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2384,13 +2364,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; wgpu::DawnTogglesDescriptor adapterTogglesDesc; adapterTogglesDesc.enabledToggles = adapterEnabledToggles; adapterTogglesDesc.enabledToggleCount = 2; - wgpu::RequestAdapterOptions options = {}; options.nextInChain = &adapterTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2406,11 +2390,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { info.nextInChain = &subgroup_matrix_configs; } +#endif ctx->adapter.GetInfo(&info); wgpu::SupportedFeatures features; @@ -2418,6 +2404,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); +#ifndef __EMSCRIPTEN__ // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { @@ -2433,36 +2420,27 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t } } + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + ctx->subgroup_size = info.subgroupMaxSize; // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->supports_subgroup_matrix) { required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } +#endif #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif - // Enable Dawn-specific toggles to increase native performance - // TODO: Don't enable for WASM builds, they won't have an effect anyways - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); @@ -2480,7 +2458,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? + const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { @@ -2578,18 +2572,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; - const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; - - wgpu::DawnTogglesDescriptor instanceTogglesDesc; - instanceTogglesDesc.enabledToggles = instanceEnabledToggles; - instanceTogglesDesc.enabledToggleCount = 1; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); instance_descriptor.requiredFeatureCount = instance_features.size(); - instance_descriptor.nextInChain = &instanceTogglesDesc; + +#ifndef __EMSCRIPTEN__ + const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; + wgpu::DawnTogglesDescriptor instanceTogglesDesc; + instanceTogglesDesc.enabledToggles = instanceEnabledToggles; + instanceTogglesDesc.enabledToggleCount = 1; + instance_descriptor.nextInChain = &instanceTogglesDesc; +#endif webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + +#ifdef __EMSCRIPTEN__ + if (webgpu_ctx->instance == nullptr) { + GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); + return nullptr; + } +#endif GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 8cc4ef1cf44..b165d8bdc62 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -1169,7 +1169,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo struct gguf_writer_base { size_t written_bytes {0u}; - ~gguf_writer_base(void) {} + ~gguf_writer_base(void) = default; // we bet on devirtualization virtual void write(int8_t val) = 0; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591b..31ff726b075 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -294,6 +294,12 @@ class ClipVision: USE_SILU = "clip.use_silu" N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + NUM_IMAGE_TOKENS = "clip.vision.num_img_tokens" # phi3v + USE_HD_TRANSFORM = "clip.vision.use_hd_transform" # phi3v + WITH_LEARNABLE_SEPERATOR = "clip.vision.with_learnable_separator" # phi3v + HD_TRANSFORM_ORDER = "clip.vision.hd_transform_order" # phi3v + NUM_CROPS = "clip.vision.num_crops" # phi3v + IMAGE_DIM_OUT = "clip.vision.image_dim_out" # phi3v class Attention: HEAD_COUNT = "clip.vision.attention.head_count" @@ -458,6 +464,7 @@ class VISION_PROJECTOR_TYPE(IntEnum): GEMMA3 = auto() QWEN3VL = auto() COGVLM = auto() + PHI3V = auto() class MODEL_TENSOR(IntEnum): @@ -685,6 +692,8 @@ class MODEL_TENSOR(IntEnum): V_MM_GATE = auto() # cogvlm V_TOK_BOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm + V_ENC_GLB_GN = auto() # phi3v + V_ENC_SUB_GN = auto() # phi3v # audio (mtmd) A_ENC_EMBD_POS = auto() A_ENC_CONV1D = auto() @@ -830,6 +839,7 @@ class MODEL_TENSOR(IntEnum): VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter", VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger", VISION_PROJECTOR_TYPE.GEMMA3: "gemma3", + VISION_PROJECTOR_TYPE.PHI3V: "phi3_v" } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1057,6 +1067,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_GATE: "mm.gate", MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_EOI: "v.eoi", + MODEL_TENSOR.V_ENC_GLB_GN: "v.glb_GN", # phi3v + MODEL_TENSOR.V_ENC_SUB_GN: "v.sub_GN", # phi3v # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", @@ -1135,6 +1147,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_GATE, MODEL_TENSOR.V_TOK_BOI, MODEL_TENSOR.V_TOK_EOI, + MODEL_TENSOR.V_ENC_GLB_GN, # Specific to Phi3V + MODEL_TENSOR.V_ENC_SUB_GN, # Specific to Phi3V # audio MODEL_TENSOR.A_ENC_EMBD_POS, MODEL_TENSOR.A_ENC_CONV1D, @@ -3327,6 +3341,7 @@ class VisionProjectorType: LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" + PHI3V = "phi3_v" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9e6ff3ac777..451f5a59332 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1070,6 +1070,9 @@ def add_classifier_output_labels(self, labels: Sequence[str]) -> None: def add_clip_has_vision_encoder(self, value: bool) -> None: self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value) + def add_clip_has_llava_projector(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_LLAVA_PROJECTOR, value) + def add_clip_has_audio_encoder(self, value: bool) -> None: self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a7b09739791..26acd9f65fe 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1201,6 +1201,7 @@ class TensorNameMap: "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 "mlp1.{bid}", # InternVL "model.aligner.fc1.hidden_layers.{bid}", # Janus Pro + "model.vision_embed_tokens.img_projection.{bid}", # phi3v ), MODEL_TENSOR.V_MMPROJ_PEG: ( @@ -1212,6 +1213,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.embeddings.class_embedding", # phi3v ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1225,6 +1227,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.embeddings.patch_embedding", # phi3v ), MODEL_TENSOR.V_ENC_EMBD_POS: ( @@ -1236,6 +1239,7 @@ class TensorNameMap: "vision_tower.patch_embed.pos_emb", # kimi-vl "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.embeddings.position_embedding", # phi3v ), MODEL_TENSOR.V_ENC_ATTN_QKV: ( @@ -1253,6 +1257,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.self_attn.q_proj", # phi3v ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1270,6 +1275,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.self_attn.k_proj", # phi3v ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1287,6 +1293,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.self_attn.v_proj", # phi3v ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1301,6 +1308,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.layer_norm1", # phi3v ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1316,6 +1324,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.self_attn.out_proj", # phi3v ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1330,6 +1339,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.layer_norm2", # phi3v ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1345,6 +1355,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.mlp.fc1", # phi3v ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1366,6 +1377,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "model.vision_embed_tokens.img_processor.vision_model.encoder.layers.{bid}.mlp.fc2", # phi3v ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1383,6 +1395,7 @@ class TensorNameMap: "vision_tower.ln_pre", # pixtral-hf "vision_encoder.ln_pre", # pixtral "vision_model.layernorm_pre", # llama4 + "model.vision_embed_tokens.img_processor.vision_model.pre_layrnorm", # phi3v ), MODEL_TENSOR.V_POST_NORM: ( @@ -1391,6 +1404,7 @@ class TensorNameMap: "vision_model.layernorm_post", # llama4 "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl + "model.vision_embed_tokens.img_processor.vision_model.post_layernorm", # phi3v ), MODEL_TENSOR.V_MM_INP_PROJ: ( @@ -1593,6 +1607,14 @@ class TensorNameMap: MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: ( "model.layers.{bid}.shared_head.norm", ), + + MODEL_TENSOR.V_ENC_GLB_GN: ( + "model.vision_embed_tokens.glb_GN", + ), + + MODEL_TENSOR.V_ENC_SUB_GN: ( + "model.vision_embed_tokens.sub_GN", + ), } # architecture-specific block mappings diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 5c6817109ba..028e5748e4d 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -31,6 +31,14 @@ else: _mistral_common_installed = True +try: + from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] + get_one_valid_tokenizer_file, + ) +except ImportError: + # We still want the conversion to work with older mistral-common versions. + get_one_valid_tokenizer_file = None + import gguf @@ -673,24 +681,30 @@ def __init__(self, base_path: Path): # Find the tokenizer files all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()] - valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) - - if len(valid_tokenizer_files) == 0: - raise ValueError(f"No tokenizer file found in the directory: {base_path}") - # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. - if len(valid_tokenizer_files) > 1: - if "tekken.json" in valid_tokenizer_files: - tokenizer_file = "tekken.json" - else: - tokenizer_file = sorted(valid_tokenizer_files)[-1] - logger.warning( - f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" - ) + + if get_one_valid_tokenizer_file is not None: + tokenizer_file_path = get_one_valid_tokenizer_file(all_files) else: - tokenizer_file = valid_tokenizer_files[0] + valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) + + if len(valid_tokenizer_files) == 0: + raise ValueError(f"No tokenizer file found in the directory: {base_path}") + # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. + if len(valid_tokenizer_files) > 1: + if "tekken.json" in valid_tokenizer_files: + tokenizer_file = "tekken.json" + else: + tokenizer_file = sorted(valid_tokenizer_files)[-1] + logger.warning( + f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" + ) + else: + tokenizer_file = valid_tokenizer_files[0] + + tokenizer_file_path = base_path / tokenizer_file self.tokenizer = MistralTokenizer.from_file( - base_path / tokenizer_file + tokenizer_file_path ).instruct_tokenizer.tokenizer self.tokenizer_type = ( MistralTokenizerType.tekken @@ -698,7 +712,7 @@ def __init__(self, base_path: Path): else MistralTokenizerType.spm ) self.vocab_size = self.tokenizer.n_words - self.fname_tokenizer = base_path / tokenizer_file + self.fname_tokenizer = tokenizer_file_path self._name = ( "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version ) diff --git a/scripts/serve-static.js b/scripts/serve-static.js new file mode 100644 index 00000000000..8ddc04aad98 --- /dev/null +++ b/scripts/serve-static.js @@ -0,0 +1,110 @@ +const http = require('http'); +const fs = require('fs').promises; +const path = require('path'); + +// This file is used for testing wasm build from emscripten +// Example build command: +// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF +// cmake --build build-wasm --target test-backend-ops -j + +const PORT = 8080; +const STATIC_DIR = path.join(__dirname, '../build-wasm/bin'); +console.log(`Serving static files from: ${STATIC_DIR}`); + +const mimeTypes = { + '.html': 'text/html', + '.js': 'text/javascript', + '.css': 'text/css', + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.gif': 'image/gif', + '.svg': 'image/svg+xml', + '.json': 'application/json', + '.woff': 'font/woff', + '.woff2': 'font/woff2', +}; + +async function generateDirListing(dirPath, reqUrl) { + const files = await fs.readdir(dirPath); + let html = ` + + + + Directory Listing + + + +

Directory: ${reqUrl}

+
    + `; + + if (reqUrl !== '/') { + html += `
  • ../ (Parent Directory)
  • `; + } + + for (const file of files) { + const filePath = path.join(dirPath, file); + const stats = await fs.stat(filePath); + const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : ''); + html += `
  • ${file}${stats.isDirectory() ? '/' : ''}
  • `; + } + + html += ` +
+ + + `; + return html; +} + +const server = http.createServer(async (req, res) => { + try { + // Set COOP and COEP headers + res.setHeader('Cross-Origin-Opener-Policy', 'same-origin'); + res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp'); + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate'); + res.setHeader('Pragma', 'no-cache'); + res.setHeader('Expires', '0'); + + const filePath = path.join(STATIC_DIR, decodeURIComponent(req.url)); + const stats = await fs.stat(filePath); + + if (stats.isDirectory()) { + const indexPath = path.join(filePath, 'index.html'); + try { + const indexData = await fs.readFile(indexPath); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(indexData); + } catch { + // No index.html, generate directory listing + const dirListing = await generateDirListing(filePath, req.url); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(dirListing); + } + } else { + const ext = path.extname(filePath).toLowerCase(); + const contentType = mimeTypes[ext] || 'application/octet-stream'; + const data = await fs.readFile(filePath); + res.writeHeader(200, { 'Content-Type': contentType }); + res.end(data); + } + } catch (err) { + if (err.code === 'ENOENT') { + res.writeHeader(404, { 'Content-Type': 'text/plain' }); + res.end('404 Not Found'); + } else { + res.writeHeader(500, { 'Content-Type': 'text/plain' }); + res.end('500 Internal Server Error'); + } + } +}); + +server.listen(PORT, () => { + console.log(`Server running at http://localhost:${PORT}/`); +}); diff --git a/src/llama-impl.h b/src/llama-impl.h index c5163e9225a..c3391e79f51 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -37,7 +37,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * template struct no_init { T value; - no_init() { /* do nothing */ } + no_init() = default; }; struct time_meas { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e281dc760bd..c3675dbdc41 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -423,8 +423,8 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s } struct llama_model::impl { - impl() {} - ~impl() {} + impl() = default; + ~impl() = default; uint64_t n_elements = 0; @@ -461,7 +461,7 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() {} +llama_model::~llama_model() = default; void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a73c4c448ba..e2cca66e48f 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -3253,8 +3253,7 @@ void llama_vocab::impl::print_info() const { llama_vocab::llama_vocab() : pimpl(new impl(*this)) { } -llama_vocab::~llama_vocab() { -} +llama_vocab::~llama_vocab() = default; void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); diff --git a/tests/.gitignore b/tests/.gitignore index cbc381606cb..ba2b164fac5 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -3,3 +3,4 @@ *.o ggml-common.h **/*.swp +!peg-parser diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9361a113a19..9ba559c8dfb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,13 +1,15 @@ llama_add_compile_flags() function(llama_build source) + set(TEST_SOURCES ${source} ${ARGN}) + if (DEFINED LLAMA_TEST_NAME) set(TEST_TARGET ${LLAMA_TEST_NAME}) else() get_filename_component(TEST_TARGET ${source} NAME_WE) endif() - add_executable(${TEST_TARGET} ${source}) + add_executable(${TEST_TARGET} ${TEST_SOURCES}) target_link_libraries(${TEST_TARGET} PRIVATE common) install(TARGETS ${TEST_TARGET} RUNTIME) endfunction() @@ -83,6 +85,8 @@ function(llama_build_and_test source) set(multiValueArgs ARGS) cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp) + if (NOT DEFINED LLAMA_TEST_LABEL) set(LLAMA_TEST_LABEL "main") endif() @@ -95,7 +99,7 @@ function(llama_build_and_test source) get_filename_component(TEST_TARGET ${source} NAME_WE) endif() - add_executable(${TEST_TARGET} ${source} get-model.cpp) + add_executable(${TEST_TARGET} ${TEST_SOURCES}) install(TARGETS ${TEST_TARGET} RUNTIME) target_link_libraries(${TEST_TARGET} PRIVATE common) @@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) endif() llama_build_and_test(test-chat-parser.cpp) +llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) +llama_build_and_test( + test-peg-parser.cpp + peg-parser/simple-tokenize.cpp + peg-parser/test-basic.cpp + peg-parser/test-gbnf-generation.cpp + peg-parser/test-json-parser.cpp + peg-parser/test-json-serialization.cpp + peg-parser/test-unicode.cpp + peg-parser/testing.h + peg-parser/tests.h +) llama_build_and_test(test-regex-partial.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") diff --git a/tests/peg-parser/simple-tokenize.cpp b/tests/peg-parser/simple-tokenize.cpp new file mode 100644 index 00000000000..9abfa0448fa --- /dev/null +++ b/tests/peg-parser/simple-tokenize.cpp @@ -0,0 +1,37 @@ +#include "simple-tokenize.h" + +std::vector simple_tokenize(const std::string & input) { + std::vector result; + std::string current; + + for (size_t i = 0; i < input.size(); i++) { + switch (input[i]) { + case ' ': + case '\n': + case '\t': + case '{': + case '}': + case ',': + case '[': + case '"': + case ']': + case '.': + case '<': + case '>': + case '=': + case '/': + if (!current.empty()) { + result.push_back(current); + current.clear(); + } + default:; + } + current += input[i]; + } + + if (!current.empty()) { + result.push_back(current); + } + + return result; +} diff --git a/tests/peg-parser/simple-tokenize.h b/tests/peg-parser/simple-tokenize.h new file mode 100644 index 00000000000..1772432c5aa --- /dev/null +++ b/tests/peg-parser/simple-tokenize.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +std::vector simple_tokenize(const std::string &); diff --git a/tests/peg-parser/test-basic.cpp b/tests/peg-parser/test-basic.cpp new file mode 100644 index 00000000000..1bda6f2e690 --- /dev/null +++ b/tests/peg-parser/test-basic.cpp @@ -0,0 +1,454 @@ +#include "tests.h" + +void test_basic(testing & t) { + t.test("chars", [](testing & t) { + // Test common escape sequences - newline + t.test("escape_sequence_newline", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\n"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_newline", true, result.success()); + }); + + // Test common escape sequences - tab + t.test("escape_sequence_tab", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\t"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_tab", true, result.success()); + }); + + // Test common escape sequences - backslash + t.test("escape_sequence_backslash", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\\"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_backslash", true, result.success()); + }); + + // Test common escape sequences - space (should ()) + t.test("escape_sequence_space_fail", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context(" "); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_space_fail", true, result.fail()); + }); + + // Test escaped dash - 'a' should succeed + t.test("escaped_dash_a", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("a"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_a", true, result.success()); + }); + + // Test escaped dash - '-' should succeed (literal dash) + t.test("escaped_dash_literal", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("-"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_literal", true, result.success()); + }); + + // Test escaped dash - 'z' should succeed + t.test("escaped_dash_z", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("z"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_z", true, result.success()); + }); + + // Test escaped dash - 'b' should NOT match (since \- is literal dash, not range) + t.test("escaped_dash_b_fail", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("b"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_b_fail", true, result.fail()); + }); + }); + + + t.test("optional", [](testing & t) { + // Full match with optional part present + t.test("optional_present", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello world"); + auto result = parser.parse(ctx); + t.assert_equal("optional_present", true, result.success()); + t.assert_equal("optional_present_end", 11u, result.end); + }); + + // Full match with optional part absent + t.test("optional_absent", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello", false); + auto result = parser.parse(ctx); + t.assert_equal("optional_absent", true, result.success()); + t.assert_equal("optional_absent_end", 5u, result.end); + }); + + // Partial match - waiting for more input to determine if optional matches + t.test("partial_match_need_more", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello ", true); + auto result = parser.parse(ctx); + t.assert_equal("partial_match_need_more", true, result.need_more_input()); + }); + }); + + t.test("partial parsing", [](testing & t) { + // Literals - Basic Success + t.test("literal_success", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("hello"); + result = parser.parse(ctx); + t.assert_equal("literal_success", true, result.success()); + }); + + // Char Classes - Basic Lowercase Success + t.test("char_class_lowercase_success", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("a"); + result = parser.parse(ctx); + t.assert_equal("char_class_lowercase_success", true, result.success()); + }); + + // Char Classes - Uppercase Fail + t.test("char_class_uppercase_fail", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("A"); + result = parser.parse(ctx); + t.assert_equal("char_class_uppercase_fail", true, result.fail()); + }); + + // Char Classes with Dash - Lowercase Success + t.test("char_class_with_dash_lowercase", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("f"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_lowercase", true, result.success()); + }); + + // Char Classes with Dash - Literal Dash Success + t.test("char_class_with_dash_literal_dash", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("-"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_literal_dash", true, result.success()); + }); + + // Char Classes with Dash - Uppercase Fail + t.test("char_class_with_dash_uppercase_fail", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("A"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_uppercase_fail", true, result.fail()); + }); + + // Sequences - Partial Match 1 + t.test("sequence_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("I am common_chat_combinator_parser", true); + auto result = parser.parse(ctx); + t.assert_equal("sequence_no_match", true, result.fail()); + }); + + // Choices - Partial Match 1 + t.test("choices_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); }); + + auto ctx = common_peg_parse_context("opt", true); + auto result = parser.parse(ctx); + t.assert_equal("choices_partial_match_1", true, result.need_more_input()); + }); + + // Choices - Partial Match 2 + t.test("choices_partial_match_2", [&](testing & t) { + auto parser = + build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); }); + + auto ctx = common_peg_parse_context("choice", true); + auto result = parser.parse(ctx); + t.assert_equal("choices_partial_match_2", true, result.need_more_input()); + }); + + // Choices - Full Match 1 + t.test("choices_full_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); }); + + auto ctx = common_peg_parse_context("first", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_full_match_1", true, result.success()); + }); + + // Choices - Full Match 2 + t.test("choices_full_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); }); + + auto ctx = common_peg_parse_context("beta", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_full_match_2", true, result.success()); + }); + + // Choices - No Match + t.test("choices_no_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); }); + + auto ctx = common_peg_parse_context("best", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_no_match", true, result.fail()); + }); + + // Zero or More - Partial Match 1 + t.test("zero_or_more_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); }); + + auto ctx = common_peg_parse_context("a", true); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_partial_match_1", true, result.need_more_input()); + }); + + // Zero or More - Partial Match 2 + t.test("zero_or_more_partial_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); }); + + auto ctx = common_peg_parse_context("xyx", true); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_partial_match_2", true, result.need_more_input()); + }); + + // Zero or More - Full Match + t.test("zero_or_more_full_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); }); + + auto ctx = common_peg_parse_context("test", false); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_full_match", true, result.success()); + }); + + // One or More - Partial Match 1 + t.test("one_or_more_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); }); + + auto ctx = common_peg_parse_context("rep", true); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_partial_match_1", true, result.need_more_input()); + }); + + // One or More - Partial Match 2 + t.test("one_or_more_partial_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); }); + + auto ctx = common_peg_parse_context("aba", true); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_partial_match_2", true, result.need_more_input()); + }); + + // One or More - Full Match + t.test("one_or_more_full_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); }); + + auto ctx = common_peg_parse_context("single", false); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_full_match", true, result.success()); + }); + + // One or More - No Match + t.test("one_or_more_no_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); }); + + auto ctx = common_peg_parse_context("success", false); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_no_match", true, result.fail()); + }); + }); + + + t.test("recursive rules", [](testing &t) { + // Test simple number + t.test("simple_number", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("1", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test simple list + t.test("simple_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[1]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test nested list + t.test("nested_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[2]]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test deeply nested list + t.test("deeply_nested_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[[3]]]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test need_more_input match + t.test("need_more_input_match", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[", true); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test no match + t.test("no_match", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[a]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_fail", true, result.fail()); + }); + }); +} diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp new file mode 100644 index 00000000000..68857a5e887 --- /dev/null +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -0,0 +1,250 @@ +#include "tests.h" + +#include "json-schema-to-grammar.h" + +#include + +static std::string trim_leading_space(const std::string & s) { + static const std::regex leading_ws_re = std::regex(R"((^|\n)\s+)"); + return std::regex_replace(s, leading_ws_re, "$1"); +} + +static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) { + t.assert_equal("gbnf are equal", trim_leading_space(expected), trim_leading_space(actual)); +} + +void test_gbnf_generation(testing &t) { + t.test("literal grammar generation", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("char class grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.chars("[a-z]", 1, 1); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= [a-z] + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("sequence grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.literal(" ") + p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" " " "world" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("choice grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("cat") | p.literal("dog"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "cat" | "dog" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("one_or_more grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.literal("a")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a"+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("zero_or_more grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.zero_or_more(p.literal("a")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a"* + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("optional grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" " world"? + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("until grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until(""); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ([^<] | "<" [^/] | "])* + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("complex expressions with parentheses", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.literal("a") | p.literal("b")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ("a" | "b")+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("rule references", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + auto digit = p.rule("digit", p.chars("[0-9]", 1, 1)); + return p.one_or_more(digit); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + digit ::= [0-9] + root ::= digit+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("escaping in literals", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello\nworld\n!"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello\nworld\n!" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("operator<< (whitespace insertion)", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") << p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" space "world" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("emit only reachable rules", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("orphan", p.literal("orphan")); + return p.literal("hello") + p.rule("child", p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + child ::= " world" + root ::= "hello" child + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("emit only trigger rules (and references)", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2")); + p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true); + p.rule("rule-3", p.literal("c") + p.ref("rule-4")); + p.rule("rule-4", p.literal("d"), true); + return rule1; + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= rule-1 + rule-1 ::= "a" rule-2 + rule-2 ::= "b" rule-3 + rule-3 ::= "c" rule-4 + rule-4 ::= "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + + auto gbnf_lazy = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder, true); + }); + + assert_gbnf_equal(t, R"""( + root ::= rule-2 | rule-4 + rule-2 ::= "b" rule-3 + rule-3 ::= "c" rule-4 + rule-4 ::= "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf_lazy); + }); +} diff --git a/tests/peg-parser/test-json-parser.cpp b/tests/peg-parser/test-json-parser.cpp new file mode 100644 index 00000000000..48351cd66ff --- /dev/null +++ b/tests/peg-parser/test-json-parser.cpp @@ -0,0 +1,109 @@ +#include "tests.h" + +void test_json_parser(testing &t) { + // Test parsing a simple JSON object + t.test("simple JSON object parsing", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing a JSON array with mixed types + t.test("JSON array with mixed types", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"([1, "hello", true, null, 3.14])"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing nested JSON with objects and arrays + t.test("nested JSON with objects and arrays", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = + R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test need_more_input() parsing - incomplete object + t.test("need_more_input() parsing - incomplete object", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"name": "test", "value": )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete array + t.test("need_more_input() parsing - incomplete array", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"([1, 2, 3, )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete nested structure + t.test("need_more_input() parsing - incomplete nested structure", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"data": {"nested": )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + t.test("object member", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.json_member("name", "\"" + p.chars("[a-z]") + "\""); + }); + + t.test("success", [&](testing &t) { + std::string input = R"("name": "bob")"; + common_peg_parse_context ctx(input, false); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + }); + + t.test("partial", [&](testing &t) { + std::string input = R"("name": "bo)"; + common_peg_parse_context ctx(input, true); + + auto result = parser.parse(ctx); + t.assert_true("need more input", result.need_more_input()); + }); + + t.test("failed", [&](testing &t) { + std::string input = R"([])"; + common_peg_parse_context ctx(input, false); + + auto result = parser.parse(ctx); + t.assert_true("fail", result.fail()); + }); + }); +} diff --git a/tests/peg-parser/test-json-serialization.cpp b/tests/peg-parser/test-json-serialization.cpp new file mode 100644 index 00000000000..a85801060c0 --- /dev/null +++ b/tests/peg-parser/test-json-serialization.cpp @@ -0,0 +1,28 @@ +#include "tests.h" + +void test_json_serialization(testing &t) { + auto original = build_peg_parser([](common_peg_parser_builder & p) { + return "" + p.json() + ""; + }); + + auto json_serialized = original.to_json().dump(); + + t.test("compare before/after", [&](testing &t) { + auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized)); + + // Test complex JSON + std::string input = R"({"name": "test", "values": [1, 2, 3], "nested": {"a": true}})"; + common_peg_parse_context ctx1(input); + common_peg_parse_context ctx2(input); + + auto result1 = original.parse(ctx1); + auto result2 = deserialized.parse(ctx2); + + t.assert_equal("both_succeed", result1.success(), result2.success()); + t.assert_equal("same_end_pos", result1.end, result2.end); + }); + + t.bench("deserialize", [&]() { + auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized)); + }, 100); +} diff --git a/tests/peg-parser/test-unicode.cpp b/tests/peg-parser/test-unicode.cpp new file mode 100644 index 00000000000..19d9b9e41c5 --- /dev/null +++ b/tests/peg-parser/test-unicode.cpp @@ -0,0 +1,449 @@ +#include "tests.h" + +#include "peg-parser.h" + +#include +#include +#include +#include + +static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) { + t.assert_equal(common_peg_parse_result_type_name(expected), common_peg_parse_result_type_name(actual)); +} + +static std::string hex_dump(const std::string& str) { + std::ostringstream oss; + for (unsigned char c : str) { + if (std::isprint(c)) { + oss << c; + } else { + oss << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast(c); + } + } + return oss.str(); +} + +void test_unicode(testing &t) { + struct test_case { + std::string input; + std::string expected_text; + common_peg_parse_result_type expected_result; + }; + + t.test("any", [](testing &t) { + std::vector test_cases { + // Valid UTF-8 sequences + {"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Incomplete UTF-8 sequences (partial bytes at end) + {std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Invalid/malformed UTF-8 sequences + {std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.one_or_more(p.any()), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("char classes", [](testing &t) { + t.test("unicode range U+4E00-U+9FFF (CJK)", [](testing &t) { + std::vector test_cases { + // Within range - CJK Unified Ideographs + {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00 + {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60 + {std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D + {std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF + + // Outside range - should fail + {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII + {std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range) + {std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range) + + // Incomplete sequences in range + {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00 + {std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("unicode range U+1F600-U+1F64F (emoticons)", [](testing &t) { + std::vector test_cases { + // Within range - Emoticons (all 4-byte UTF-8) + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600 + {std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601 + {std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F + + // Outside range + {std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range) + {std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range) + {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range) + + // Incomplete sequences + {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji + {std::string("\xF0\x9F"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("mixed unicode ranges", [](testing &t) { + std::vector test_cases { + // Match CJK + {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00 + {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60 + + // Match emoticons + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600 + + // Match ASCII digits + {"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Don't match outside any range + {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 + + // Incomplete + {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + }); + + t.test("until parser", [](testing &t) { + t.test("ASCII delimiter with Unicode content", [](testing &t) { + std::vector test_cases { + // CJK characters before delimiter + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Emoji before delimiter + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mixed content + {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("incomplete UTF-8 at end", [](testing &t) { + std::vector test_cases { + // Incomplete emoji at end, no delimiter + {std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete CJK at end, no delimiter + {std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Complete content, no delimiter (should consume all valid UTF-8) + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("malformed UTF-8", [](testing &t) { + std::vector test_cases { + // Invalid UTF-8 bytes + {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Continuation byte without lead byte + {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Invalid continuation byte + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + }); + } + }); + }); + + t.test("json_string parser", [](testing &t) { + t.test("valid UTF-8 characters", [](testing &t) { + std::vector test_cases { + // ASCII only + {"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 2-byte UTF-8 (accented characters) + {std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 3-byte UTF-8 (CJK) + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 4-byte UTF-8 (emoji) + {std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mixed content + {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.json_string_content(), p.literal("\"")}); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("incomplete UTF-8", [](testing &t) { + std::vector test_cases { + // Incomplete 2-byte sequence + {std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete 3-byte sequence + {std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete 4-byte sequence + {std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete at very start + {std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.json_string_content(); + }); + + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("malformed UTF-8", [](testing &t) { + std::vector test_cases { + // Invalid UTF-8 bytes + {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Continuation byte without lead byte + {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Invalid continuation byte + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Overlong encoding (security issue) + {std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.json_string_content(); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + }); + } + }); + + t.test("escape sequences with UTF-8", [](testing &t) { + std::vector test_cases { + // Unicode escape sequence + {"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mix of UTF-8 and escape sequences + {std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Escaped quote in UTF-8 string + {std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.json_string_content(), p.literal("\"")}); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + }); +} diff --git a/tests/peg-parser/testing.h b/tests/peg-parser/testing.h new file mode 100644 index 00000000000..45ac4ca7842 --- /dev/null +++ b/tests/peg-parser/testing.h @@ -0,0 +1,243 @@ +#pragma once + +#include "common.h" + +#include +#include +#include +#include +#include +#include + +struct testing { + std::ostream &out; + std::vector stack; + std::regex filter; + bool filter_tests = false; + bool throw_exception = false; + bool verbose = false; + int tests = 0; + int assertions = 0; + int failures = 0; + int unnamed = 0; + int exceptions = 0; + + static constexpr std::size_t status_column = 80; + + explicit testing(std::ostream &os = std::cout) : out(os) {} + + std::string indent() const { + if (stack.empty()) { + return ""; + } + return std::string((stack.size() - 1) * 2, ' '); + } + + std::string full_name() const { + return string_join(stack, "."); + } + + void log(const std::string & msg) { + if (verbose) { + out << indent() << " " << msg << "\n"; + } + } + + void set_filter(const std::string & re) { + filter = std::regex(re); + filter_tests = true; + } + + bool should_run() const { + if (filter_tests) { + if (!std::regex_match(full_name(), filter)) { + return false; + } + } + return true; + } + + template + void run_with_exceptions(F &&f, const char *ctx) { + try { + f(); + } catch (const std::exception &e) { + ++failures; + ++exceptions; + out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n"; + if (throw_exception) { + throw; + } + } catch (...) { + ++failures; + ++exceptions; + out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n"; + if (throw_exception) { + throw; + } + } + } + + void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const { + std::string line = indent() + label; + + std::string details; + if (new_assertions > 0) { + if (new_failures == 0) { + details = std::to_string(new_assertions) + " assertion(s)"; + } else { + details = std::to_string(new_failures) + " of " + + std::to_string(new_assertions) + " assertion(s) failed"; + } + } + if (!extra.empty()) { + if (!details.empty()) { + details += ", "; + } + details += extra; + } + + if (!details.empty()) { + line += " (" + details + ")"; + } + + std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]"; + + if (line.size() + 1 < status_column) { + line.append(status_column - line.size(), ' '); + } else { + line.push_back(' '); + } + + out << line << status << "\n"; + } + + template + void test(const std::string &name, F f) { + stack.push_back(name); + if (!should_run()) { + stack.pop_back(); + return; + } + + ++tests; + out << indent() << name << "\n"; + + int before_failures = failures; + int before_assertions = assertions; + + run_with_exceptions([&] { f(*this); }, "test"); + + int new_failures = failures - before_failures; + int new_assertions = assertions - before_assertions; + + print_result(name, new_failures, new_assertions); + + stack.pop_back(); + } + + template + void test(F f) { + test("test #" + std::to_string(++unnamed), f); + } + + template + void bench(const std::string &name, F f, int iterations = 100) { + stack.push_back(name); + if (!should_run()) { + stack.pop_back(); + return; + } + + ++tests; + out << indent() << "[bench] " << name << "\n"; + + int before_failures = failures; + int before_assertions = assertions; + + using clock = std::chrono::high_resolution_clock; + + std::chrono::microseconds duration(0); + + run_with_exceptions([&] { + for (auto i = 0; i < iterations; i++) { + auto start = clock::now(); + f(); + duration += std::chrono::duration_cast(clock::now() - start); + } + }, "bench"); + + auto avg_elapsed = duration.count() / iterations; + auto avg_elapsed_s = std::chrono::duration_cast>(duration).count() / iterations; + auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0; + + int new_failures = failures - before_failures; + int new_assertions = assertions - before_assertions; + + std::string extra = + "n=" + std::to_string(iterations) + + " avg=" + std::to_string(avg_elapsed) + "us" + + " rate=" + std::to_string(int(rate)) + "/s"; + + print_result("[bench] " + name, new_failures, new_assertions, extra); + + stack.pop_back(); + } + + template + void bench(F f, int iterations = 100) { + bench("bench #" + std::to_string(++unnamed), f, iterations); + } + + // Assertions + bool assert_true(bool cond) { + return assert_true("", cond); + } + + bool assert_true(const std::string &msg, bool cond) { + ++assertions; + if (!cond) { + ++failures; + out << indent() << "ASSERT TRUE FAILED"; + if (!msg.empty()) { + out << " : " << msg; + } + out << "\n"; + return false; + } + return true; + } + + template + bool assert_equal(const A &expected, const B &actual) { + return assert_equal("", expected, actual); + } + + template + bool assert_equal(const std::string &msg, const A &expected, const B &actual) { + ++assertions; + if (!(actual == expected)) { + ++failures; + out << indent() << "ASSERT EQUAL FAILED"; + if (!msg.empty()) { + out << " : " << msg; + } + out << "\n"; + + out << indent() << " expected: " << expected << "\n"; + out << indent() << " actual : " << actual << "\n"; + return false; + } + return true; + } + + // Print summary and return an exit code + int summary() const { + out << "\n"; + out << "tests : " << tests << "\n"; + out << "assertions : " << assertions << "\n"; + out << "failures : " << failures << "\n"; + out << "exceptions : " << exceptions << "\n"; + return failures == 0 ? 0 : 1; + } +}; diff --git a/tests/peg-parser/tests.h b/tests/peg-parser/tests.h new file mode 100644 index 00000000000..25727682c8a --- /dev/null +++ b/tests/peg-parser/tests.h @@ -0,0 +1,24 @@ +#pragma once + +// Common includes for all test files +#include +#include +#include + +#include "testing.h" +#include "peg-parser.h" +#include "chat-peg-parser.h" +#include "simple-tokenize.h" + +struct bench_tool_call { + std::string id; + std::string name; + nlohmann::ordered_json args; +}; + +// Test function declarations +void test_basic(testing &t); +void test_json_parser(testing &t); +void test_gbnf_generation(testing &t); +void test_unicode(testing &t); +void test_json_serialization(testing &t); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9645d0b3909..7ef7f2ad81e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -41,12 +41,18 @@ #include #include +#ifdef __EMSCRIPTEN__ +# define N_THREADS 1 +#else +# define N_THREADS std::thread::hardware_concurrency() +#endif + static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); std::vector data(nels); { // parallel initialization - static const size_t n_threads = std::thread::hardware_concurrency(); + static const size_t n_threads = N_THREADS; // static RNG initialization (revisit if n_threads stops being constant) static std::vector generators = []() { std::random_device rd; @@ -65,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } }; - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*nels/n_threads; - size_t end = (i+1)*nels/n_threads; - tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); - } - for (auto & t : tasks) { - t.get(); + if (n_threads == 1) { + init_thread(0, 0, nels); + } else { + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } @@ -105,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m }; const size_t min_blocks_per_thread = 1; - const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, - std::max(1, n_blocks / min_blocks_per_thread)); - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*n_blocks/n_threads; - size_t end = (i+1)*n_blocks/n_threads; - tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); - } - for (auto & t : tasks) { - t.get(); + const size_t n_quant_threads = std::min(std::max(N_THREADS/2, 1), + std::max(1, n_blocks / min_blocks_per_thread)); + + if (n_quant_threads == 1) { + // single-threaded quantization: do all blocks in the current thread + quantize_thread(0, n_blocks); + } else { + std::vector> tasks; + tasks.reserve(n_quant_threads); + for (size_t i = 0; i < n_quant_threads; i++) { + size_t start = i*n_blocks/n_quant_threads; + size_t end = (i+1)*n_blocks/n_quant_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); @@ -8363,7 +8379,7 @@ int main(int argc, char ** argv) { auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); if (ggml_backend_set_n_threads_fn) { // TODO: better value for n_threads - ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency()); + ggml_backend_set_n_threads_fn(backend, N_THREADS); } size_t free, total; // NOLINT diff --git a/tests/test-chat-peg-parser.cpp b/tests/test-chat-peg-parser.cpp new file mode 100644 index 00000000000..fbbb9c82efb --- /dev/null +++ b/tests/test-chat-peg-parser.cpp @@ -0,0 +1,768 @@ +#include +#include +#include + +#include "chat-parser.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "peg-parser.h" +#include "peg-parser/testing.h" +#include "peg-parser/simple-tokenize.h" +#include "nlohmann/json.hpp" + +using json = nlohmann::ordered_json; + +static json create_tools(); +static void test_example_native(testing & t); +static void test_example_qwen3_coder(testing & t); +static void test_command7_parser_compare(testing & t); + +int main(int argc, char *argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("native", test_example_native); + t.test("qwen3 coder", test_example_qwen3_coder); + t.test("comparison", test_command7_parser_compare); + + return t.summary(); +} + +static json create_tools() { + json tools = json::array(); + + json tool_weather = { + {"type", "function"}, + {"function", { + {"name", "get_current_weather"}, + {"description", "Get the current weather in a given location"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"location", { + {"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"} + }}, + {"unit", { + {"type", "string"}, + {"enum", {"celsius", "fahrenheit"}}, + {"description", "The temperature unit to use. Infer this from the users location."} + }} + }}, + {"required", {"location", "unit"}}, + }}, + }} + }; + tools.push_back(tool_weather); + + json tool_forecast = { + {"type", "function"}, + {"function", { + {"name", "get_forecast"}, + {"description", "Get the weather forecast for a given location"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"location", { + {"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"} + }}, + {"unit", { + {"type", "string"}, + {"enum", {"celsius", "fahrenheit"}}, + {"description", "The temperature unit to use. Infer this from the users location."} + }}, + {"days", { + {"type", "integer"}, + {"description", "Number of days to forecast (1-10)"}, + {"minimum", 1}, + {"maximum", 10} + }} + }}, + {"required", {"location", "unit"}}, + }}, + }} + }; + tools.push_back(tool_forecast); + + json tool_search = { + {"type", "function"}, + {"function", { + {"name", "search_knowledge_base"}, + {"description", "Search the internal technical documentation knowledge base."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"query", { + {"type", "string"}, + {"description", "The search query string."} + }}, + {"max_results", { + {"type", "integer"}, + {"description", "The maximum number of results to return."}, + {"default", 5} + }}, + {"category", { + {"type", "string"}, + {"enum", {"api", "troubleshooting", "billing", "general"}}, + {"description", "Filter search by specific category."} + }} + }}, + {"required", {"query", "category"}}, + {"additionalProperties", false} + }}, + {"strict", true} + }} + }; + tools.push_back(tool_search); + + return tools; +} + +struct tool_argument { + std::string name; + std::string type; + bool is_required; + json schema; +}; + +struct tool_definition { + std::string name; + std::vector arguments; + json schema; +}; + +// Test fictitious model output that emits arguments as JSON. +static void test_example_native(testing & t) { + struct test_case { + // Parameters + std::string name; + json tools; + common_chat_tool_choice tool_choice; + common_reasoning_format reasoning_format; + json json_schema; + bool parallel_tool_calls; + bool thinking_forced_open; + std::string input; + + // Expect + std::string expect_reasoning; + std::string expect_content; + std::vector expect_tool_calls; + }; + + auto build_parser = [](const test_case & tc) { + return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE); + auto reasoning = p.eps(); + if (tc.thinking_forced_open) { + // If thinking is forced open, expect a closing tag + reasoning = p.reasoning(p.until("")) + "" + p.space(); + } else { + // Otherwise, optionally accept thinking wrapped in tags + reasoning = p.optional("" + p.reasoning(p.until("")) + "" + p.space()); + } + + // tool calling parser + if (tc.tools.is_array() && !tc.tools.empty()) { + auto tools = p.choice(); + for (const auto & tool : tc.tools) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\""); + auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); + + tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}"); + }; + + auto parallel_calls = p.eps(); + if (tc.parallel_tool_calls) { + parallel_calls = p.zero_or_more("," << tools); + } + + auto tool_call = p.trigger_rule("tool-call", + p.sequence({ + p.literal("["), + tools, + parallel_calls, + p.literal("]") + }) + ); + + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.until("")), + p.optional(p.space() + tool_call), + p.space(), + p.end() + }); + } + + // response_format parser + if (tc.json_schema.is_object() && !tc.json_schema.empty()) { + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.schema(p.json(), "response-output", tc.json_schema)), + p.space(), + p.end() + }); + } + + // Content-only parser + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.rest()), + p.end() + }); + }); + }; + + std::vector test_cases = std::vector{ + { + /* .name = */ "content with thinking_forced_open = false", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = false and no reasoning", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ( + "Hello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = false and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = true and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "tools with tool_choice = auto and no parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must get the weather in New York\n" + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + "]" + ), + /* .expect_reasoning = */ "I must get the weather in New York", + /* .expect_content = */ "", + /* .expect_tool_calls = */ {{ + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }}, + }, + { + /* .name = */ "tools with tool_choice = auto and parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ true, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must get the weather in New York and San Francisco and a 3 day forecast of each.\nLet me search that for you." + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})" + "]" + ), + /* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.", + /* .expect_content = */ "Let me search that for you.", + /* .expect_tool_calls = */ {{ + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }, { + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})", + /* .id = */ "", + }, { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }, { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }}, + }, + { + /* .name = */ "response_format with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ { + {"type", "object"}, + {"properties", { + {"invoice_number", {{"type", "string"}}}, + {"amount", {{"type", "number"}}}, + {"due_date", {{"type", "string"}}} + }}, + {"required", {"invoice_number", "amount", "due_date"}} + }, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must produce the invoice in the requested format\n" + R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})" + ), + /* .expect_reasoning = */ "I must produce the invoice in the requested format", + /* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})", + /* .expect_tool_calls = */ {}, + }, + }; + + for (const auto & tc : test_cases) { + t.test(tc.name, [&](testing & t) { + auto parser = build_parser(tc); + auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (auto const & def : tc.tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder, lazy); + }); + + t.log("Grammar:"); + for (auto const & line : string_split(grammar, "\n")) { + t.log(line); + } + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + t.assert_true("success", result.success()); + + common_chat_msg msg; + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("content equal", tc.expect_content, msg.content); + t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content); + t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size()); + for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) { + t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name); + t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments); + } + }); + } +} + +static void test_example_qwen3_coder(testing & t) { + auto tools = create_tools(); + auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto content = p.rule("content", p.content(p.until(""))); + + std::vector tool_parsers; + for (auto const & def : tools) { + auto function = def.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + auto properties = parameters.at("properties"); + + std::set required_properties; + if (function.contains("required")) { + function.at("required").get_to(required_properties); + } + + std::vector arg_parsers; + for (const auto & [param_name, param_schema] : properties.items()) { + bool is_required = required_properties.find(param_name) != required_properties.end(); + auto type = param_schema.value("type", "object"); + + auto arg = p.tool_arg(p.sequence({ + p.tool_arg_open(""), + (type == "string" ? + p.tool_arg_string_value( + p.schema( + p.until_one_of({ + "\n\n" + }), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, + true + ) + ) : p.tool_arg_json_value( + p.schema( + p.json(), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema + ) + ) + ), + p.tool_arg_close( + "\n" + + p.peek(p.literal("")) + ) + })); + + arg_parsers.push_back(is_required ? + p.rule("tool-" + name + "-arg-" + param_name, arg) : + p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg))); + } + + tool_parsers.push_back(p.rule("tool-" + name, + p.tool_open("") + << p.sequence(arg_parsers) + << p.tool_close(p.literal("")) + )); + }; + + auto tool_call = p.trigger_rule("tool-call", + "" + << p.choice(tool_parsers) + << "" + ); + + return content + p.zero_or_more(p.space() + tool_call) + p.end(); + }); + + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (auto const & def : tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder); + }); + + t.log("Grammar:"); + for (auto const & line : string_split(grammar, "\n")) { + t.log(line); + } + + t.test("incremental parsing", [&](testing &t) { + std::string input = + "Let me search the knowledge base for cat pictures." + "\n" + "\n" + "cat pictures\n" + "general\n" + "\n" + ""; + + std::vector tokens = simple_tokenize(input); + + common_chat_msg prev; + for (auto it = tokens.begin(); it != tokens.end(); it++) { + std::string in = std::accumulate(tokens.begin(), it + 1, std::string()); + + common_peg_parse_context ctx(in, it + 1 < tokens.end()); + + auto result = parser.parse(ctx); + if (!t.assert_equal("not fail", false, result.fail())) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + } + + common_chat_msg msg; + auto mapper = common_chat_peg_constructed_mapper(msg); + mapper.from_ast(ctx.ast, result); + + //t.log("Input: " + input); + t.log("==========================================="); + t.log("Iteration " + std::to_string(in.size())); + t.log("Reasoning: " + msg.reasoning_content); + t.log("Content : " + msg.content); + for (const auto & tc : msg.tool_calls) { + t.log("Tool name: " + tc.name); + t.log("Tool args: " + tc.arguments); + } + + try { + // This shouldn't emit any runtime errors + auto diffs = common_chat_msg_diff::compute_diffs(prev, msg); + } catch(const std::exception & e) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + t.assert_true(std::string("failed with ") + e.what(), false); + } + + prev = msg; + } + }); +} + +void test_command7_parser_compare(testing & t) { + auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) { + auto thinking = p.reasoning_block( + "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); + + auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>"; + + auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\""))); + auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\""))); + auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json())); + + auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); + auto tool_call = p.rule("tool-call", p.tool( + p.tool_open(p.literal("{")) + << tool_call_fields + << p.zero_or_more( p.literal(",") << tool_call_fields) + << p.tool_close(p.literal("}")) + )); + + auto tool_calls = p.rule("tool-calls", + "<|START_ACTION|>" + << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") + << "<|END_ACTION|>"); + + return p.optional(thinking) << (tool_calls | response) + p.end(); + }); + + auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) { + common_peg_parse_context ctx(input, is_partial); + auto result = p.parse(ctx); + + common_chat_msg msg; + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (print_results) { + std::cout << "== Parsed (new) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << msg.reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << msg.content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : msg.tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } + }; + + auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) { + // Original common_chat_combinator_parser taken from chat.cpp + common_chat_msg_parser builder( + input, + /* .is_partial = */ need_more_input, + { + /* .format = */ COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + } + ); + + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } }); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } + + if (print_results) { + std::cout << "== Parsed (legacy) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << builder.result().reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << builder.result().content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : builder.result().tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } + }; + + std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " + "budget of $4000 for a two-week stay, we need to:\n\n" + "1. Identify key historical sites and modern attractions in Japan.\n" + "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" + "3. Determine the best modes of transportation for getting around Japan.\n" + "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " + "overspending.\n" + "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " + "to attractions."; + + std::vector> tool_calls = {{ + "call_0", + "plan_trip", + nlohmann::json::parse(R"({ + "destination": "Japan", + "duration": 14, + "budget": 4000, + "interests": ["historical sites", "modern attractions"], + "accommodation_preferences": "affordable", + "transportation_preferences": "efficient", + "meal_preferences": "local cuisine" + })") + }}; + + std::vector tokens; + + // Build tokens + if (!reasoning.empty()) { + auto tokenized = simple_tokenize(reasoning); + tokens.emplace_back("<|START_THINKING|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_THINKING|>"); + } + + if (!tool_calls.empty()) { + tokens.emplace_back("<|START_ACTION|>"); + + auto json = nlohmann::json::array(); + for (const auto & tc : tool_calls) { + auto tc_json = nlohmann::json::object(); + tc_json["tool_call_id"] = std::get<0>(tc); + tc_json["tool_name"] = std::get<1>(tc); + tc_json["parameters"] = std::get<2>(tc); + json.push_back(tc_json); + } + + auto tokenized = simple_tokenize(json.dump(-1, ' ', true)); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + + tokens.emplace_back("<|END_ACTION|>"); + } + + std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); + + // Run tests + t.test("legacy_parse", [&](testing & /* t */) { + test_legacy(input, false, false); + }); + + t.test("current_parse", [&](testing & /* t */) { + test_current(parser, input, false, false); + }); + + // Run benchmarks + t.bench("legacy_parse_benchmark complete", [&]() { + test_legacy(input, false, false); + }); + + t.bench("legacy_parse_benchmark incremental", [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + + try { + test_legacy(in, i + 1 < tokens.size(), false); + } catch (common_chat_msg_partial_exception & /* e */) { + // Do nothing, this is expected + } + } + }, 20); + + t.bench("current_parse_benchmark complete", [&]() { + test_current(parser, input, false, false); + }, 100); + + t.bench("current_parse_benchmark incremental", [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + test_current(parser, in, i + 1 < tokens.size(), false); + } + }, 20); +} diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 1e568219d21..6a4bd8fb4dc 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1375,7 +1375,7 @@ int main() { try { tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true)); tc.verify_status(SUCCESS); - } catch (const std::runtime_error & ex) { + } catch (const std::invalid_argument & ex) { fprintf(stderr, "Error: %s\n", ex.what()); tc.verify_status(FAILURE); } diff --git a/tests/test-peg-parser.cpp b/tests/test-peg-parser.cpp new file mode 100644 index 00000000000..220745d0293 --- /dev/null +++ b/tests/test-peg-parser.cpp @@ -0,0 +1,25 @@ +#include +#include +#include + +#include "peg-parser/tests.h" + +int main(int argc, char *argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("basic", test_basic); + t.test("unicode", test_unicode); + t.test("json", test_json_parser); + t.test("gbnf", test_gbnf_generation); + t.test("serialization", test_json_serialization); + + return t.summary(); +} diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cd47865bf4a..4c3fc2df268 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -39,6 +39,11 @@ #define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" +#define KEY_PHI3_HD_ORDER "clip.vision.hd_transform_order" +#define KEY_PHI3_NUM_IMG_TOKENS "clip.vision.num_img_tokens" +#define KEY_PHI3_NUM_CROPS "clip.vision.num_crops" +#define KEY_PHI3_USE_HD "clip.vision.use_hd_transform" +#define KEY_PHI3_WITH_SEP "clip.vision.with_learnable_separator" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" @@ -86,6 +91,8 @@ #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_IMAGE_NEWLINE "model.image_newline" +#define TN_PHI3_GLB_GN "v.glb_GN" // phi3v +#define TN_PHI3_SUB_GN "v.sub_GN" // phi3v #define TN_MM_INP_NORM "mm.input_norm.weight" #define TN_MM_INP_NORM_B "mm.input_norm.bias" #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 @@ -156,6 +163,7 @@ enum projector_type { PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, + PROJECTOR_TYPE_PHI3_V, PROJECTOR_TYPE_UNKNOWN, }; @@ -182,6 +190,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, + { PROJECTOR_TYPE_PHI3_V, "phi3_v"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index ea89259f92d..5deb191db29 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -206,6 +206,13 @@ struct clip_hparams { int32_t custom_image_min_tokens = -1; int32_t custom_image_max_tokens = -1; + // phi3v + bool use_hd = false; + bool with_learnable_separator = false; + std::string hd_order = "sub_glb"; + int32_t num_img_tokens = 144; + int32_t num_crops = -1; + void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) { const int cur_merge = n_merge == 0 ? 1 : n_merge; const int patch_area = patch_size * patch_size * cur_merge * cur_merge; @@ -399,6 +406,15 @@ struct clip_model { ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_eoi = nullptr; + // phi3v + ggml_tensor * mm_glb_GN = nullptr; // global separator + ggml_tensor * mm_sub_GN = nullptr; // sub-image separator + bool phi3_setup_done = false; + + // Pre-calculated projected vectors + std::vector phi3_proj_glb_GN; + std::vector phi3_proj_sub_GN; + bool audio_has_avgpool() const { return proj_type == PROJECTOR_TYPE_QWEN2A || proj_type == PROJECTOR_TYPE_VOXTRAL; @@ -428,6 +444,7 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; + bool is_allocated = false; // for debugging bool debug_graph = false; @@ -469,6 +486,9 @@ struct clip_ctx { if (ctx_params.image_max_tokens > 0) { model.hparams.custom_image_max_tokens = ctx_params.image_max_tokens; } + if (ctx_params.num_crops > 0) { + model.hparams.num_crops = ctx_params.num_crops; + } backend_ptrs.push_back(backend_cpu); backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); @@ -1999,6 +2019,70 @@ struct clip_graph { return gf; } + static struct ggml_tensor * ggml_phi3v_hd_merge( + struct ggml_context * ctx, + struct ggml_tensor * image_features, + int h_crop, + int w_crop, + int patch_size + ) { + int n_images = image_features->ne[3]; + const int n_channels = image_features->ne[0];; + const int patch_size_half = patch_size/2; + + struct ggml_tensor * t = image_features; + t = ggml_reshape_4d(ctx, t, n_channels, 2, patch_size_half, patch_size * n_images); + t = ggml_reshape_4d(ctx, t, n_channels * 2, patch_size_half, 2, patch_size_half * n_images); + t = ggml_permute(ctx, t, 0, 2, 1, 3); + t = ggml_cont(ctx, t); + + return t; + } + + ggml_cgraph * build_phi3v() { + ggml_tensor * inp = build_inp(); + int n_patches_per_crop = inp->ne[1]; + int patch_size_transformed = sqrt(n_patches_per_crop); + + int num_crops = inp->ne[2]; + inp = ggml_reshape_3d(ctx0, inp, n_embd, n_patches_per_crop, num_crops); + + ggml_tensor * cls = model.class_embedding; + cls = ggml_reshape_3d(ctx0, cls, n_embd, 1, 1); + cls = ggml_repeat(ctx0, cls, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1, num_crops)); + + inp = ggml_concat(ctx0, cls, inp, 1); + inp = ggml_add(ctx0, inp, model.position_embeddings); + + inp = ggml_reshape_2d(ctx0, inp, n_embd, (n_patches_per_crop + 1) * num_crops); + + ggml_tensor * cur = build_vit(inp, (n_patches_per_crop + 1) * num_crops, + NORM_TYPE_NORMAL, hparams.ffn_op, nullptr, nullptr); + + + cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_per_crop + 1, num_crops); + + cur = ggml_view_3d(ctx0, cur, + n_embd, n_patches_per_crop, num_crops, + cur->nb[1], cur->nb[2], + cur->nb[1]); + + cur = ggml_reshape_4d(ctx0, cur, n_embd, patch_size_transformed, patch_size_transformed, num_crops); + cur = ggml_cont(ctx0, cur); + + cur = ggml_phi3v_hd_merge(ctx0, cur, 1, 1, patch_size_transformed); + cur = ggml_reshape_2d(ctx0, cur, hparams.n_ff, hparams.num_img_tokens * num_crops); + + ggml_tensor * final_emb = ggml_mul_mat(ctx0, model.mm_0_w, cur); + final_emb = ggml_add(ctx0, final_emb, model.mm_0_b); + final_emb = ggml_gelu(ctx0, final_emb); + final_emb = ggml_mul_mat(ctx0, model.mm_2_w, final_emb); + final_emb = ggml_add(ctx0, final_emb, model.mm_2_b); + + ggml_build_forward_expand(gf, final_emb); + return gf; + } + private: // // utility functions @@ -2532,6 +2616,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_cogvlm(); } break; + case PROJECTOR_TYPE_PHI3_V: + { + res = graph.build_phi3v(); + } break; default: { res = graph.build_llava(); @@ -2861,6 +2949,24 @@ struct clip_model_loader { hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging } break; + case PROJECTOR_TYPE_PHI3_V: + { + get_bool(KEY_PHI3_USE_HD, hparams.use_hd, false); + if (!hparams.use_hd) { + LOG_WRN("%s: Phi-3-Vision model missing %s=true, assuming HD transform is required\n", __func__, KEY_PHI3_USE_HD); + } + + get_string(KEY_PHI3_HD_ORDER, hparams.hd_order, false); + if (hparams.hd_order != "sub_glb") { + throw std::runtime_error(string_format("%s: unsupported HD transform order: %s (only 'sub_glb' supported)\n", __func__, hparams.hd_order.c_str())); + } + + get_u32(KEY_PHI3_NUM_IMG_TOKENS, hparams.num_img_tokens, false); + if (hparams.num_crops == -1) { + get_u32(KEY_PHI3_NUM_CROPS, hparams.num_crops, false); + } + get_bool(KEY_PHI3_WITH_SEP, hparams.with_learnable_separator, false); + } break; default: break; } @@ -3248,6 +3354,20 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; + case PROJECTOR_TYPE_PHI3_V: + { + // Load MLP weights: mm.model.mlp.0.weight / bias + model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + + // Load MLP weights: mm.model.mlp.2.weight / bias + model.mm_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); + + // Load Separators + model.mm_glb_GN = get_tensor(TN_PHI3_GLB_GN); + model.mm_sub_GN = get_tensor(TN_PHI3_SUB_GN); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -3305,12 +3425,30 @@ struct clip_model_loader { }; static void warmup(clip_ctx & ctx_clip) { + // create a fake batch + const auto & hparams = ctx_clip.model.hparams; + clip_image_f32_batch batch; + clip_image_f32_ptr img(clip_image_f32_init()); + if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { + img->nx = hparams.warmup_image_size; + img->ny = hparams.warmup_image_size; + LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); + } else { + img->nx = hparams.warmup_audio_size; + img->ny = hparams.n_mel_bins; + LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); + } + batch.entries.push_back(std::move(img)); + warmup(ctx_clip, batch); + } + + static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { support_info_graph info; if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) { // try to enable flash attention to see if it's supported ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED; - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && info.fattn_op) { auto op = info.fattn_op; LOG_WRN("%s: *****************************************************************\n", __func__); @@ -3329,15 +3467,17 @@ struct clip_model_loader { LOG_WRN("%s: please report this on github as an issue\n", __func__); LOG_WRN("%s: *****************************************************************\n", __func__); ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED; - alloc_compute_meta(ctx_clip); + alloc_compute_meta(ctx_clip, batch); } } else { - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__); } } + ctx_clip.is_allocated = true; // mark buffers as allocated + LOG_INF("%s: flash attention is %s\n", __func__, (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled"); @@ -3369,24 +3509,9 @@ struct clip_model_loader { } } - static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) { - const auto & hparams = ctx_clip.model.hparams; + static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); - // create a fake batch - clip_image_f32_batch batch; - clip_image_f32_ptr img(clip_image_f32_init()); - if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { - img->nx = hparams.warmup_image_size; - img->ny = hparams.warmup_image_size; - LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); - } else { - img->nx = hparams.warmup_audio_size; - img->ny = hparams.n_mel_bins; - LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); - } - batch.entries.push_back(std::move(img)); - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); @@ -4208,6 +4333,77 @@ struct llava_uhd { } }; +struct phi3v_hd { + struct slice_instructions { + clip_image_size overview_size; + clip_image_size grid_size; + std::vector crops; + }; + + static int padding_336(int x) { + return (int)(std::ceil((float)x / 336.0f) * 336); + } + + static clip_image_size calc_hd_transform_size(int width, int height, int hd_num) { + bool transposed = false; + if (width < height) { + std::swap(width, height); + transposed = true; + } + float ratio = (float)width / (float)height; + + int scale = 1; + while (scale * std::ceil(scale / ratio) <= hd_num) { + scale++; + } + scale--; + if (scale < 1) scale = 1; + + + int new_w = scale * 336; + int new_h = (int)((float)new_w / ratio); + + clip_image_size res; + res.width = padding_336(new_w); + res.height = padding_336(new_h); + + if (transposed) std::swap(res.width, res.height); + return res; + } + + static std::vector transform(const clip_image_u8 * img, int num_crops, int & out_grid_x, int & out_grid_y) { + std::vector output; + + // 1. HD Transform (Resize + Pad) + clip_image_size hd_size = calc_hd_transform_size(img->nx, img->ny, num_crops); + clip_image_u8 hd_img; + img_tool::resize(*img, hd_img, hd_size, img_tool::RESIZE_ALGO_BICUBIC, true, {255, 255, 255}); + + out_grid_x = hd_size.width / 336; + out_grid_y = hd_size.height / 336; + + // 2. Slice into 336x336 patches + // Iterate Y then X (Row-Major) + for (int y = 0; y < hd_size.height; y += 336) { + for (int x = 0; x < hd_size.width; x += 336) { + clip_image_u8_ptr slice(clip_image_u8_init()); + + // Logic: Copy 336x336 area at (x,y) + img_tool::crop(hd_img, *slice, x, y, 336, 336); + output.push_back(std::move(slice)); + } + } + + // 3. Global Image (Resize to 336x336) + clip_image_u8_ptr global(clip_image_u8_init()); + // Global uses Bicubic + img_tool::resize(*img, *global, {336, 336}, img_tool::RESIZE_ALGO_BICUBIC, true, {255, 255, 255}); + output.push_back(std::move(global)); + + return output; + } + }; + // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { @@ -4423,7 +4619,21 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } } } break; + case PROJECTOR_TYPE_PHI3_V: + { + int max_crops = ctx->model.hparams.num_crops; + int gx, gy; + auto imgs = phi3v_hd::transform(img, max_crops, gx, gy); + ctx->model.hparams.image_crop_resolution = (gy << 16) | gx; + for (auto & crop : imgs) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*crop, *res, ctx->model.hparams.image_mean, ctx->model.hparams.image_std); + res_imgs->entries.push_back(std::move(res)); + } + res_imgs->grid_x = gx; + res_imgs->grid_y = gy; + } break; default: LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type()); return false; @@ -4604,6 +4814,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { n_patches += 2; // for BOI and EOI token embeddings } break; + case PROJECTOR_TYPE_PHI3_V: + { + n_patches = ctx->model.hparams.num_img_tokens; + } break; default: GGML_ABORT("unsupported projector type"); } @@ -4630,6 +4844,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return false; // only support batch size of 1 } + // if buffers are not allocated, we need to do a warmup run to allocate them + if (!ctx->is_allocated) { + clip_model_loader::warmup(*ctx, *imgs_c_ptr); + } + // build the inference graph ctx->debug_print_tensors.clear(); ggml_backend_sched_reset(ctx->sched.get()); @@ -4955,6 +5174,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } set_input_i32("pos_w", pos_data); } break; + case PROJECTOR_TYPE_PHI3_V: + { + // do nothing + // Phi-3 uses learned position embeddings which are added + // inside the graph (build_phi3v -> build_vit), + // so no external input tensors are needed. + } break; default: GGML_ABORT("Unknown projector type"); } @@ -5045,6 +5271,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; + case PROJECTOR_TYPE_PHI3_V: + return ctx->model.mm_2_b->ne[0]; // 3072 default: GGML_ABORT("Unknown projector type"); } @@ -5071,6 +5299,10 @@ bool clip_is_llava(const struct clip_ctx * ctx) { return ctx->model.hparams.has_llava_projector; } +// [NEW] Phi-3-Vision Helper Implementation +bool clip_is_phi3v(const struct clip_ctx * ctx) { + return ctx->proj_type() == PROJECTOR_TYPE_PHI3_V; +} bool clip_is_gemma3(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; } @@ -5120,3 +5352,166 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel batch->entries.push_back(clip_image_f32_ptr(audio)); batch->is_audio = true; } + +static void clip_phi3_setup(clip_ctx * ctx) { + if (ctx->model.phi3_setup_done) return; + + LOG_INF("%s: pre-computing Phi-3 special tokens...\n", __func__); + + // Setup tiny graph + struct ggml_init_params params = { 1024*1024, NULL, true }; + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // Helper to build MLP graph for a single vector + auto build_mlp = [&](ggml_tensor* input) { + ggml_tensor* cur = ggml_mul_mat(ctx0, ctx->model.mm_0_w, input); + if (ctx->model.mm_0_b) { + ggml_tensor* b = ctx->model.mm_0_b; + cur = ggml_add(ctx0, cur, b); + } + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, ctx->model.mm_2_w, cur); + if (ctx->model.mm_2_b) { + ggml_tensor* b = ctx->model.mm_2_b; + cur = ggml_add(ctx0, cur, b); + } + return cur; + }; + + // 1. Project Global Separator (glb_GN) + ggml_tensor* res_glb = nullptr; + if (ctx->model.mm_glb_GN) { + res_glb = build_mlp(ctx->model.mm_glb_GN); + ggml_build_forward_expand(gf, res_glb); + } + + // 2. Project Sub/Newline Separator (sub_GN) + ggml_tensor* res_sub = nullptr; + if (ctx->model.mm_sub_GN) { + res_sub = build_mlp(ctx->model.mm_sub_GN); + ggml_build_forward_expand(gf, res_sub); + } + + ggml_backend_sched_reset(ctx->sched.get()); + ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); + ggml_backend_sched_graph_compute(ctx->sched.get(), gf); + + int dim = clip_n_mmproj_embd(ctx); + + if (res_glb) { + ctx->model.phi3_proj_glb_GN.resize(dim); + ggml_backend_tensor_get(res_glb, ctx->model.phi3_proj_glb_GN.data(), 0, dim * sizeof(float)); + } + + if (res_sub) { + ctx->model.phi3_proj_sub_GN.resize(dim); + ggml_backend_tensor_get(res_sub, ctx->model.phi3_proj_sub_GN.data(), 0, dim * sizeof(float)); + } + + ggml_free(ctx0); + ggml_backend_sched_reset(ctx->sched.get()); + + ctx->model.phi3_setup_done = true; +} + +bool clip_image_batch_encode_phi3(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec) { + if (!ctx || !imgs || !vec) return false; + + clip_phi3_setup(ctx); + + if (ctx->model.phi3_proj_sub_GN.empty() || ctx->model.phi3_proj_glb_GN.empty()) { + fprintf(stderr, "%s: Error - Phi-3 separators not initialized.\n", __func__); + return false; + } + + const auto & entries = imgs->entries; + int n_crops = entries.size(); + + if (n_crops < 1) return false; + + int dim = clip_n_mmproj_embd(ctx); + + int w_crop = imgs->grid_x; + int h_crop = imgs->grid_y; + int n_sub_images = w_crop * h_crop; + + if (n_sub_images > n_crops - 1) { + return false; + } + + const int sub_crop_tokens = ctx->model.hparams.num_img_tokens; + const int grid_side = (int)sqrt(sub_crop_tokens); // Due to the 2 x 2 hd_transform + + std::vector crop_output(sub_crop_tokens * dim); + + std::vector all_local_crops; + if (n_sub_images > 0) { + all_local_crops.resize(n_sub_images * sub_crop_tokens * dim); + } + + float* dest = vec; + + for (int i = 0; i < n_sub_images; ++i) { + + bool ok = clip_image_encode(ctx, n_threads, entries[i].get(), crop_output.data()); + if (!ok) return false; + + // Copy into the storage buffer linearly + memcpy(all_local_crops.data() + (i * sub_crop_tokens * dim), + crop_output.data(), + sub_crop_tokens * dim * sizeof(float)); + } + + if (n_sub_images > 0) { + for (int row_global = 0; row_global < h_crop * grid_side; ++row_global) { + int crop_y = row_global / grid_side; + int internal_y = row_global % grid_side; + for (int crop_x = 0; crop_x < w_crop; ++crop_x) { + int crop_idx = crop_y * w_crop + crop_x; + + float* src = all_local_crops.data() + + (crop_idx * sub_crop_tokens * dim) + + (internal_y * grid_side * dim); + + memcpy(dest, src, grid_side * dim * sizeof(float)); + dest += (grid_side * dim); + } + + if (!ctx->model.phi3_proj_sub_GN.empty()) { + memcpy(dest, ctx->model.phi3_proj_sub_GN.data(), dim * sizeof(float)); + } else { + memset(dest, 0, dim * sizeof(float)); + } + dest += dim; + } + } + + if (!ctx->model.phi3_proj_glb_GN.empty()) { + memcpy(dest, ctx->model.phi3_proj_glb_GN.data(), dim * sizeof(float)); + } else { + memset(dest, 0, dim * sizeof(float)); + } + dest += dim; + + { + bool ok = clip_image_encode(ctx, n_threads, entries.back().get(), crop_output.data()); + if (!ok) return false; + + float* src = crop_output.data(); + + for (int r = 0; r < grid_side; ++r) { + memcpy(dest, src, grid_side * dim * sizeof(float)); + dest += (grid_side * dim); + src += (grid_side * dim); + + if (!ctx->model.phi3_proj_sub_GN.empty()) { + memcpy(dest, ctx->model.phi3_proj_sub_GN.data(), dim * sizeof(float)); + } else { + memset(dest, 0, dim * sizeof(float)); + } + dest += dim; + } + } + return true; +} \ No newline at end of file diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index e8aeb2066c6..d7469f8d532 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "mtmd.h" #include #include @@ -34,6 +35,7 @@ struct clip_context_params { enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + int num_crops; bool warmup; }; @@ -105,6 +107,15 @@ bool clip_is_glm(const struct clip_ctx * ctx); bool clip_is_qwen2vl(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); +bool clip_is_phi3v(const struct clip_ctx * ctx); + +// Handles looping, separator injection, and stitching internally. +MTMD_API bool clip_image_batch_encode_phi3( + struct clip_ctx * ctx, + int n_threads, + const struct clip_image_f32_batch * imgs, + float * vec +); bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index b5bbc6536b5..485d664f2f8 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -73,6 +73,7 @@ struct mtmd_cli_context { common_sampler * smpl; llama_batch batch; int n_batch; + int n_crops; mtmd::bitmaps bitmaps; @@ -96,6 +97,7 @@ struct mtmd_cli_context { n_threads = params.cpuparams.n_threads; batch = llama_batch_init(1, 0, 1); // batch for next token generation n_batch = params.n_batch; + n_crops = params.num_crops; if (!model || !lctx) { exit(1); @@ -139,6 +141,7 @@ struct mtmd_cli_context { mparams.warmup = params.warmup; mparams.image_min_tokens = params.image_min_tokens; mparams.image_max_tokens = params.image_max_tokens; + mparams.num_crops = params.num_crops; ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d06fa42e616..b4312e78b56 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -4,6 +4,7 @@ #include "mtmd-audio.h" #include "llama.h" +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -111,6 +112,7 @@ mtmd_context_params mtmd_context_params_default() { /* warmup */ true, /* image_min_tokens */ -1, /* image_max_tokens */ -1, + /* num_crops */ -1, }; return params; } @@ -178,6 +180,7 @@ struct mtmd_context { /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* num_crops */ ctx_params.num_crops, /* warmup */ ctx_params.warmup, }; @@ -306,6 +309,11 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_PHI3_V) { + img_beg = ""; + img_end = ""; + tok_row_end_trail = false; + ov_img_first = false; } else if (proj == PROJECTOR_TYPE_LFM2) { img_beg = "<|image_start|>"; img_end = "<|image_end|>"; @@ -537,9 +545,37 @@ struct mtmd_tokenizer { LOG_ERR("Unable to preprocess image\n"); return 2; } + // Phi-3-Vision Token Calculation Logic + if (clip_is_phi3v(ctx->ctx_v)) { + const int n_col = batch_f32.grid_x; + const int n_row = batch_f32.grid_y; + const int n_tokens_per_crop = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get()); + const int n_token_for_global_crop = (int)std::sqrt((n_tokens_per_crop)); + + int n_sub_images = n_col * n_row; + size_t local_tokens = (n_sub_images + 1) * n_tokens_per_crop; + + if (n_sub_images == 0) local_tokens = 0; + + size_t global_tokens = (n_row + 1) * n_token_for_global_crop; + size_t separator_tokens = 1; + size_t n_tokens = local_tokens + separator_tokens + global_tokens; - // handle llava-uhd style preprocessing - if ( + mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); + image_tokens->nx = n_tokens; + image_tokens->ny = 1; + image_tokens->batch_f32 = std::move(batch_f32); + image_tokens->id = bitmap->id; + + mtmd_input_chunk chunk{ + MTMD_INPUT_CHUNK_TYPE_IMAGE, + {}, + std::move(image_tokens), + nullptr, + }; + cur.entries.emplace_back(std::move(chunk)); + } + else if ( ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 @@ -812,7 +848,16 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - if (clip_is_llava(ctx_clip) + if (clip_is_phi3v(ctx_clip)) { + // Delegate the entire stitching logic to clip.cpp + ok = clip_image_batch_encode_phi3( + ctx_clip, + ctx->n_threads, + &image_tokens->batch_f32, + ctx->image_embd_v.data() + ); + } + else if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index b3df24c299d..46f033eb37a 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -87,6 +87,7 @@ struct mtmd_context_params { // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) int image_max_tokens; // maximum number of tokens for image input (default: read from metadata) + int num_crops; // Phi-3.5 Vison max number of crops }; MTMD_API const char * mtmd_default_marker(void); diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index fb71c7aa7be..1aa659a9066 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -2,11 +2,6 @@ set(TARGET llama-server) include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) -if (MINGW) - # fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006 - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - if (NOT LLAMA_HTTPLIB) message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF") endif() diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index b911b6e769e..4527cb33567 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index f48ea5b62a5..cfdd0c656f4 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -11,6 +11,7 @@ #include #include +#include json format_error_response(const std::string & message, const enum error_type type) { std::string type_str; @@ -774,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) { return llama_params; } +// media_path always end with '/', see arg.cpp +static void handle_media( + std::vector & out_files, + json & media_obj, + const std::string & media_path) { + std::string url = json_value(media_obj, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %zu bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else if (string_starts_with(url, "file://")) { + if (media_path.empty()) { + throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified"); + } + // load local image file + std::string file_path = url.substr(7); // remove "file://" + raw_buffer data; + if (!fs_validate_filename(file_path, true)) { + throw std::invalid_argument("file path is not allowed: " + file_path); + } + SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str()); + std::ifstream file(media_path + file_path, std::ios::binary); + if (!file) { + throw std::invalid_argument("file does not exist or cannot be opened: " + file_path); + } + data.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + out_files.push_back(data); + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } +} + // used by /chat/completions endpoint json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ @@ -819,26 +879,26 @@ json oaicompat_chat_params_parse( auto schema_wrapper = json_value(response_format, "json_schema", json::object()); json_schema = json_value(schema_wrapper, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } // get input files if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); + throw std::invalid_argument("'messages' is required"); } json & messages = body.at("messages"); if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array"); + throw std::invalid_argument("Expected 'messages' to be an array"); } for (auto & msg : messages) { std::string role = json_value(msg, "role", std::string()); if (role != "assistant" && !msg.contains("content")) { - throw std::runtime_error("All non-assistant messages must contain 'content'"); + throw std::invalid_argument("All non-assistant messages must contain 'content'"); } if (role == "assistant") { if (!msg.contains("content") && !msg.contains("tool_calls")) { - throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); + throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!"); } if (!msg.contains("content")) { continue; // avoid errors with no content @@ -850,7 +910,7 @@ json oaicompat_chat_params_parse( } if (!content.is_array()) { - throw std::runtime_error("Expected 'content' to be a string or an array"); + throw std::invalid_argument("Expected 'content' to be a string or an array"); } for (auto & p : content) { @@ -860,41 +920,8 @@ json oaicompat_chat_params_parse( throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json image_url = json_value(p, "image_url", json::object()); - std::string url = json_value(image_url, "url", std::string()); - if (string_starts_with(url, "http")) { - // download remote image - // TODO @ngxson : maybe make these params configurable - common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); - params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds - SRV_INF("downloading image from '%s'\n", url.c_str()); - auto res = common_remote_get_content(url, params); - if (200 <= res.first && res.first < 300) { - SRV_INF("downloaded %ld bytes\n", res.second.size()); - raw_buffer data; - data.insert(data.end(), res.second.begin(), res.second.end()); - out_files.push_back(data); - } else { - throw std::runtime_error("Failed to download image"); - } - - } else { - // try to decode base64 image - std::vector parts = string_split(url, /*separator*/ ','); - if (parts.size() != 2) { - throw std::runtime_error("Invalid image_url.url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid image_url.url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("image_url.url must be base64 encoded"); - } else { - auto base64_data = parts[1]; - auto decoded_data = base64_decode(base64_data); - out_files.push_back(decoded_data); - } - } + json image_url = json_value(p, "image_url", json::object()); + handle_media(out_files, image_url, opt.media_path); // replace this chunk with a marker p["type"] = "text"; @@ -911,18 +938,20 @@ json oaicompat_chat_params_parse( std::string format = json_value(input_audio, "format", std::string()); // while we also support flac, we don't allow it here so we matches the OAI spec if (format != "wav" && format != "mp3") { - throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'"); } auto decoded_data = base64_decode(data); // expected to be base64 encoded out_files.push_back(decoded_data); + // TODO: add audio_url support by reusing handle_media() + // replace this chunk with a marker p["type"] = "text"; p["text"] = mtmd_default_marker(); p.erase("input_audio"); } else if (type != "text") { - throw std::runtime_error("unsupported content[].type"); + throw std::invalid_argument("unsupported content[].type"); } } } @@ -940,7 +969,7 @@ json oaicompat_chat_params_parse( inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); + throw std::invalid_argument("Cannot use custom grammar constraints with tools."); } llama_params["parse_tool_calls"] = true; } @@ -959,7 +988,7 @@ json oaicompat_chat_params_parse( } else if (enable_thinking_kwarg == "false") { inputs.enable_thinking = false; } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { - throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)"); + throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); } // if the assistant message appears at the end of list, we do not add end-of-turn token @@ -972,14 +1001,14 @@ json oaicompat_chat_params_parse( /* sanity check, max one assistant message at the end of the list */ if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); } /* TODO: test this properly */ inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; if ( inputs.enable_thinking ) { - throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); + throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking."); } inputs.add_generation_prompt = true; @@ -1016,22 +1045,25 @@ json oaicompat_chat_params_parse( for (const auto & stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } + if (!chat_params.parser.empty()) { + llama_params["chat_parser"] = chat_params.parser; + } // Handle "n" field int n_choices = json_value(body, "n", 1); if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); + throw std::invalid_argument("Only one completion choice is allowed"); } // Handle "logprobs" field // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future if (json_value(body, "logprobs", false)) { if (has_tools && stream) { - throw std::runtime_error("logprobs is not supported with tools + stream"); + throw std::invalid_argument("logprobs is not supported with tools + stream"); } llama_params["n_probs"] = json_value(body, "top_logprobs", 20); } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { - throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + throw std::invalid_argument("top_logprobs requires logprobs to be set to true"); } // Copy remaining properties to llama_params diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 51ae9ea8a96..bb04e82b4f5 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -284,6 +284,7 @@ struct oaicompat_parser_options { bool allow_image; bool allow_audio; bool enable_thinking = true; + std::string media_path; }; // used by /chat/completions endpoint diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aac1a70bb2b..c9245745756 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -788,6 +788,7 @@ struct server_context_impl { /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, + /* media_path */ params_base.media_path, }; // print sample chat example to make it clear which template is used diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index ac7f6b86bf8..c1fbaf4ec91 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -889,6 +890,28 @@ struct pipe_t { } }; +static std::string to_lower_copy(const std::string & value) { + std::string lowered(value.size(), '\0'); + std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); }); + return lowered; +} + +static bool should_strip_proxy_header(const std::string & header_name) { + // Headers that get duplicated when router forwards child responses + if (header_name == "server" || + header_name == "transfer-encoding" || + header_name == "keep-alive") { + return true; + } + + // Router injects CORS, child also sends them: duplicate + if (header_name.rfind("access-control-", 0) == 0) { + return true; + } + + return false; +} + server_http_proxy::server_http_proxy( const std::string & method, const std::string & host, @@ -925,6 +948,14 @@ server_http_proxy::server_http_proxy( msg_t msg; msg.status = response.status; for (const auto & [key, value] : response.headers) { + const auto lowered = to_lower_copy(key); + if (should_strip_proxy_header(lowered)) { + continue; + } + if (lowered == "content-type") { + msg.content_type = value; + continue; + } msg.headers[key] = value; } return pipe->write(std::move(msg)); // send headers first @@ -932,7 +963,7 @@ server_http_proxy::server_http_proxy( httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) { // send data chunks // returns false if pipe is closed / broken (signal to stop receiving) - return pipe->write({{}, 0, std::string(data, data_length)}); + return pipe->write({{}, 0, std::string(data, data_length), ""}); }; // prepare the request to destination server @@ -955,8 +986,8 @@ server_http_proxy::server_http_proxy( if (result.error() != httplib::Error::Success) { auto err_str = httplib::to_string(result.error()); SRV_ERR("http client error: %s\n", err_str.c_str()); - pipe->write({{}, 500, ""}); // header - pipe->write({{}, 0, "proxy error: " + err_str}); // body + pipe->write({{}, 500, "", ""}); // header + pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body } pipe->close_write(); // signal EOF to reader SRV_DBG("%s", "client request thread ended\n"); @@ -964,12 +995,17 @@ server_http_proxy::server_http_proxy( this->thread.detach(); // wait for the first chunk (headers) - msg_t header; - if (pipe->read(header, should_stop)) { - SRV_DBG("%s", "received response headers\n"); - this->status = header.status; - this->headers = header.headers; - } else { - SRV_DBG("%s", "no response headers received (request cancelled?)\n"); + { + msg_t header; + if (pipe->read(header, should_stop)) { + SRV_DBG("%s", "received response headers\n"); + this->status = header.status; + this->headers = std::move(header.headers); + if (!header.content_type.empty()) { + this->content_type = std::move(header.content_type); + } + } else { + SRV_DBG("%s", "no response headers received (request cancelled?)\n"); + } } } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index b9bec983ef6..526e7488dc9 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -170,5 +170,6 @@ struct server_http_proxy : server_http_res { std::map headers; int status = 0; std::string data; + std::string content_type; }; }; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 3f59127fb2f..8a9477d7321 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl( params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + if (data.contains("chat_parser")) { + params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get()); + } } { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 950537d82d0..d5bef3df445 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -34,18 +34,26 @@ static inline void signal_handler(int signal) { static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { std::string message; + error_type error; try { return func(req); + } catch (const std::invalid_argument & e) { + // treat invalid_argument as invalid request (400) + error = ERROR_TYPE_INVALID_REQUEST; + message = e.what(); } catch (const std::exception & e) { + // treat other exceptions as server error (500) + error = ERROR_TYPE_SERVER; message = e.what(); } catch (...) { + error = ERROR_TYPE_SERVER; message = "unknown error"; } auto res = std::make_unique(); res->status = 500; try { - json error_data = format_error_response(message, ERROR_TYPE_SERVER); + json error_data = format_error_response(message, error); res->status = json_value(error_data, "code", 500); res->data = safe_json_to_str({{ "error", error_data }}); SRV_WRN("got exception: %s\n", res->data.c_str()); diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index cadaa91849f..3405be3e25d 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -65,6 +65,7 @@ def test_server_slots(): def test_load_split_model(): global server + server.offline = False server.model_hf_repo = "ggml-org/models" server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" server.model_alias = "tinyllama-split" diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 093cec91555..aa6229c93a5 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -199,7 +199,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int choice = res.body["choices"][0] assert match_regex(re_content, choice["message"]["content"]) else: - assert res.status_code != 200 + assert res.status_code == 400 assert "error" in res.body diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index e6f3c6485c0..e85f2c33829 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -17,7 +17,6 @@ def create_server(): ] ) def test_router_chat_completion_stream(model: str, success: bool): - # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server global server server.start() content = "" @@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool): else: assert ex is not None assert content == "" + + +def _get_model_status(model_id: str) -> str: + res = server.make_request("GET", "/models") + assert res.status_code == 200 + for item in res.body.get("data", []): + if item.get("id") == model_id or item.get("model") == model_id: + return item["status"]["value"] + raise AssertionError(f"Model {model_id} not found in /models response") + + +def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str: + deadline = time.time() + timeout + last_status = None + while time.time() < deadline: + last_status = _get_model_status(model_id) + if last_status in desired: + return last_status + time.sleep(1) + raise AssertionError( + f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}" + ) + + +def _load_model_and_wait( + model_id: str, timeout: int = 60, headers: dict | None = None +) -> None: + load_res = server.make_request( + "POST", "/models/load", data={"model": model_id}, headers=headers + ) + assert load_res.status_code == 200 + assert isinstance(load_res.body, dict) + assert load_res.body.get("success") is True + _wait_for_model_status(model_id, {"loaded"}, timeout=timeout) + + +def test_router_unload_model(): + global server + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + _load_model_and_wait(model_id) + + unload_res = server.make_request("POST", "/models/unload", data={"model": model_id}) + assert unload_res.status_code == 200 + assert unload_res.body.get("success") is True + _wait_for_model_status(model_id, {"unloaded"}) + + +def test_router_models_max_evicts_lru(): + global server + server.models_max = 2 + server.start() + + candidate_models = [ + "ggml-org/tinygemma3-GGUF:Q8_0", + "ggml-org/test-model-stories260K", + "ggml-org/test-model-stories260K-infill", + ] + + # Load only the first 2 models to fill the cache + first, second, third = candidate_models[:3] + + _load_model_and_wait(first, timeout=120) + _load_model_and_wait(second, timeout=120) + + # Verify both models are loaded + assert _get_model_status(first) == "loaded" + assert _get_model_status(second) == "loaded" + + # Load the third model - this should trigger LRU eviction of the first model + _load_model_and_wait(third, timeout=120) + + # Verify eviction: third is loaded, first was evicted + assert _get_model_status(third) == "loaded" + assert _get_model_status(first) == "unloaded" + + +def test_router_no_models_autoload(): + global server + server.no_models_autoload = True + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 400 + assert "error" in res.body + + _load_model_and_wait(model_id) + + success_res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert success_res.status_code == 200 + assert "error" not in success_res.body + + +def test_router_api_key_required(): + global server + server.api_key = "sk-router-secret" + server.start() + + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + auth_headers = {"Authorization": f"Bearer {server.api_key}"} + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 401 + assert res.body.get("error", {}).get("type") == "authentication_error" + + _load_model_and_wait(model_id, headers=auth_headers) + + authed = server.make_request( + "POST", + "/v1/chat/completions", + headers=auth_headers, + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert authed.status_code == 200 + assert "error" not in authed.body diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index e160a8e6d30..8c38b89d535 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -94,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str): assert res.status_code == 200 assert cors_header in res.headers assert res.headers[cors_header] == cors_header_value + + +@pytest.mark.parametrize( + "media_path, image_url, success", + [ + (None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail + ("../../../tools", "file://mtmd/test-1.jpeg", True), + ("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above + ("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file + ("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal + ] +) +def test_local_media_file(media_path, image_url, success,): + server = ServerPreset.tinygemma3() + server.media_path = media_path + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 1, + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "test"}, + {"type": "image_url", "image_url": { + "url": image_url, + }}, + ]}, + ], + }) + if success: + assert res.status_code == 200 + else: + assert res.status_code == 400 diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index afe4f77d978..48e7403602f 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -7,6 +7,7 @@ import os import re import json +from json import JSONDecodeError import sys import requests import time @@ -83,6 +84,9 @@ class ServerProcess: pooling: str | None = None draft: int | None = None api_key: str | None = None + models_dir: str | None = None + models_max: int | None = None + no_models_autoload: bool | None = None lora_files: List[str] | None = None enable_ctx_shift: int | None = False draft_min: int | None = None @@ -95,6 +99,7 @@ class ServerProcess: chat_template_file: str | None = None server_path: str | None = None mmproj_url: str | None = None + media_path: str | None = None # session variables process: subprocess.Popen | None = None @@ -142,6 +147,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--hf-repo", self.model_hf_repo]) if self.model_hf_file: server_args.extend(["--hf-file", self.model_hf_file]) + if self.models_dir: + server_args.extend(["--models-dir", self.models_dir]) + if self.models_max is not None: + server_args.extend(["--models-max", self.models_max]) if self.n_batch: server_args.extend(["--batch-size", self.n_batch]) if self.n_ubatch: @@ -203,6 +212,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.no_models_autoload: + server_args.append("--no-models-autoload") if self.jinja: server_args.append("--jinja") else: @@ -217,6 +228,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--chat-template-file", self.chat_template_file]) if self.mmproj_url: server_args.extend(["--mmproj-url", self.mmproj_url]) + if self.media_path: + server_args.extend(["--media-path", self.media_path]) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") @@ -292,7 +305,13 @@ def make_request( result = ServerResponse() result.headers = dict(response.headers) result.status_code = response.status_code - result.body = response.json() if parse_body else None + if parse_body: + try: + result.body = response.json() + except JSONDecodeError: + result.body = response.text + else: + result.body = None print("Response from server", json.dumps(result.body, indent=2)) return result @@ -431,8 +450,9 @@ def load_all() -> None: @staticmethod def tinyllama2() -> ServerProcess: server = ServerProcess() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K.gguf" + server.offline = True # will be downloaded by load_all() + server.model_hf_repo = "ggml-org/test-model-stories260K" + server.model_hf_file = None server.model_alias = "tinyllama-2" server.n_ctx = 512 server.n_batch = 32 @@ -476,8 +496,8 @@ def bert_bge_small_with_fa() -> ServerProcess: def tinyllama_infill() -> ServerProcess: server = ServerProcess() server.offline = True # will be downloaded by load_all() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_hf_repo = "ggml-org/test-model-stories260K-infill" + server.model_hf_file = None server.model_alias = "tinyllama-infill" server.n_ctx = 2048 server.n_batch = 1024 @@ -534,6 +554,7 @@ def tinygemma3() -> ServerProcess: @staticmethod def router() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() # router server has no models server.model_file = None server.model_alias = None diff --git a/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md b/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md index f5c4f05edf6..bccacf56841 100644 --- a/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md +++ b/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md @@ -15,7 +15,7 @@ sequenceDiagram Stores->>DB: load conversations Stores->>API: GET /props API-->>Stores: {role: "router"} - Stores->>API: GET /models + Stores->>API: GET /v1/models API-->>Stores: models[] with status (loaded/available) loop each loaded model Stores->>API: GET /props?model=X @@ -28,7 +28,7 @@ sequenceDiagram alt model not loaded Stores->>API: POST /models/load loop poll status - Stores->>API: GET /models + Stores->>API: GET /v1/models API-->>Stores: check if loaded end Stores->>API: GET /props?model=X diff --git a/tools/server/webui/docs/flows/models-flow.md b/tools/server/webui/docs/flows/models-flow.md index ce63da1b367..c3031b72923 100644 --- a/tools/server/webui/docs/flows/models-flow.md +++ b/tools/server/webui/docs/flows/models-flow.md @@ -56,7 +56,7 @@ sequenceDiagram UI->>modelsStore: fetchRouterModels() activate modelsStore modelsStore->>ModelsSvc: listRouter() - ModelsSvc->>API: GET /models + ModelsSvc->>API: GET /v1/models API-->>ModelsSvc: ApiRouterModelsListResponse Note right of API: {data: [{id, status, path, in_cache}]} modelsStore->>modelsStore: routerModels = $state(data) @@ -132,7 +132,7 @@ sequenceDiagram loop poll every 500ms (max 60 attempts) modelsStore->>modelsStore: fetchRouterModels() modelsStore->>ModelsSvc: listRouter() - ModelsSvc->>API: GET /models + ModelsSvc->>API: GET /v1/models API-->>ModelsSvc: models[] modelsStore->>modelsStore: getModelStatus(modelId) alt status === LOADED @@ -165,7 +165,7 @@ sequenceDiagram modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED) loop poll until unloaded modelsStore->>ModelsSvc: listRouter() - ModelsSvc->>API: GET /models + ModelsSvc->>API: GET /v1/models end modelsStore->>modelsStore: modelLoadingStates.set(modelId, false) diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 97dccd8be8f..7f8e38286d2 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -64,7 +64,10 @@ let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined); let isRecording = $state(false); let message = $state(''); - let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500); + let pasteLongTextToFileLength = $derived.by(() => { + const n = Number(currentConfig.pasteLongTextToFileLen); + return Number.isNaN(n) ? 2500 : n; + }); let previousIsLoading = $state(isLoading); let recordingSupported = $state(false); let textareaRef: ChatFormTextarea | undefined = $state(undefined); diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte index 6024f66c8bd..f307f829bc6 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte @@ -1,6 +1,5 @@ @@ -70,8 +55,8 @@
Loading model information...
- {:else if modelsData && modelsData.data.length > 0} - {@const modelMeta = modelsData.data[0].meta} + {:else if firstModel} + {@const modelMeta = firstModel.meta} {#if serverProps} diff --git a/tools/server/webui/src/lib/constants/supported-file-types.ts b/tools/server/webui/src/lib/constants/supported-file-types.ts index 93bbab5d399..0d955ad1479 100644 --- a/tools/server/webui/src/lib/constants/supported-file-types.ts +++ b/tools/server/webui/src/lib/constants/supported-file-types.ts @@ -126,8 +126,13 @@ export const TEXT_FILE_TYPES = { mimeTypes: [MimeTypeText.JAVA] }, [FileTypeText.CPP]: { - extensions: [FileExtensionText.CPP, FileExtensionText.C, FileExtensionText.H], - mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.C_SRC, MimeTypeText.C_HDR] + extensions: [ + FileExtensionText.CPP, + FileExtensionText.C, + FileExtensionText.H, + FileExtensionText.HPP + ], + mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.CPP_HDR, MimeTypeText.C_SRC, MimeTypeText.C_HDR] }, [FileTypeText.PHP]: { extensions: [FileExtensionText.PHP], @@ -183,10 +188,30 @@ export const TEXT_FILE_TYPES = { }, [FileTypeText.LATEX]: { extensions: [FileExtensionText.TEX], - mimeTypes: [MimeTypeText.LATEX] + mimeTypes: [MimeTypeText.LATEX, MimeTypeText.TEX, MimeTypeText.TEX_APP] }, [FileTypeText.BIBTEX]: { extensions: [FileExtensionText.BIB], mimeTypes: [MimeTypeText.BIBTEX] + }, + [FileTypeText.CUDA]: { + extensions: [FileExtensionText.CU, FileExtensionText.CUH], + mimeTypes: [MimeTypeText.CUDA] + }, + [FileTypeText.VULKAN]: { + extensions: [FileExtensionText.COMP], + mimeTypes: [MimeTypeText.PLAIN] + }, + [FileTypeText.HASKELL]: { + extensions: [FileExtensionText.HS], + mimeTypes: [MimeTypeText.HASKELL] + }, + [FileTypeText.CSHARP]: { + extensions: [FileExtensionText.CS], + mimeTypes: [MimeTypeText.CSHARP] + }, + [FileTypeText.PROPERTIES]: { + extensions: [FileExtensionText.PROPERTIES], + mimeTypes: [MimeTypeText.PROPERTIES] } } as const; diff --git a/tools/server/webui/src/lib/enums/files.ts b/tools/server/webui/src/lib/enums/files.ts index 45b0feea169..a4f079d405f 100644 --- a/tools/server/webui/src/lib/enums/files.ts +++ b/tools/server/webui/src/lib/enums/files.ts @@ -62,7 +62,12 @@ export enum FileTypeText { VUE = 'vue', SVELTE = 'svelte', LATEX = 'latex', - BIBTEX = 'bibtex' + BIBTEX = 'bibtex', + CUDA = 'cuda', + VULKAN = 'vulkan', + HASKELL = 'haskell', + CSHARP = 'csharp', + PROPERTIES = 'properties' } // File extension enums @@ -121,7 +126,14 @@ export enum FileExtensionText { VUE = '.vue', SVELTE = '.svelte', TEX = '.tex', - BIB = '.bib' + BIB = '.bib', + CU = '.cu', + CUH = '.cuh', + COMP = '.comp', + HPP = '.hpp', + HS = '.hs', + PROPERTIES = '.properties', + CS = '.cs' } // MIME type enums @@ -165,7 +177,10 @@ export enum MimeTypeText { CSV = 'text/csv', PYTHON = 'text/x-python', JAVA = 'text/x-java-source', + CPP_HDR = 'text/x-c++hdr', CPP_SRC = 'text/x-c++src', + CSHARP = 'text/x-csharp', + HASKELL = 'text/x-haskell', C_SRC = 'text/x-csrc', C_HDR = 'text/x-chdr', PHP = 'text/x-php', @@ -182,6 +197,10 @@ export enum MimeTypeText { DART = 'text/x-dart', VUE = 'text/x-vue', SVELTE = 'text/x-svelte', - LATEX = 'text/x-tex', - BIBTEX = 'text/x-bibtex' + TEX = 'text/x-tex', + TEX_APP = 'application/x-tex', + LATEX = 'application/x-latex', + BIBTEX = 'text/x-bibtex', + CUDA = 'text/x-cuda', + PROPERTIES = 'text/properties' } diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index 70b18c8a00c..a6a68124035 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -677,48 +677,6 @@ export class ChatService { // Utilities // ───────────────────────────────────────────────────────────────────────────── - /** - * Get server properties - static method for API compatibility (to be refactored) - */ - static async getServerProps(): Promise { - try { - const response = await fetch(`./props`, { - headers: getJsonHeaders() - }); - - if (!response.ok) { - throw new Error(`Failed to fetch server props: ${response.status}`); - } - - const data = await response.json(); - return data; - } catch (error) { - console.error('Error fetching server props:', error); - throw error; - } - } - - /** - * Get model information from /models endpoint (to be refactored) - */ - static async getModels(): Promise { - try { - const response = await fetch(`./models`, { - headers: getJsonHeaders() - }); - - if (!response.ok) { - throw new Error(`Failed to fetch models: ${response.status} ${response.statusText}`); - } - - const data = await response.json(); - return data; - } catch (error) { - console.error('Error fetching models:', error); - throw error; - } - } - /** * Injects a system message at the beginning of the conversation if provided. * Checks for existing system messages to avoid duplication. diff --git a/tools/server/webui/src/lib/services/models.ts b/tools/server/webui/src/lib/services/models.ts index f031bd74975..eecb7fa2628 100644 --- a/tools/server/webui/src/lib/services/models.ts +++ b/tools/server/webui/src/lib/services/models.ts @@ -7,7 +7,7 @@ import { getJsonHeaders } from '$lib/utils'; * * This service handles communication with model-related endpoints: * - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode) - * - `/models` - Router-specific model management (ROUTER mode only) + * - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only) * * **Responsibilities:** * - List available models @@ -43,7 +43,7 @@ export class ModelsService { * Returns models with load status, paths, and other metadata */ static async listRouter(): Promise { - const response = await fetch(`${base}/models`, { + const response = await fetch(`${base}/v1/models`, { headers: getJsonHeaders() }); diff --git a/tools/server/webui/src/lib/stores/conversations.svelte.ts b/tools/server/webui/src/lib/stores/conversations.svelte.ts index 44ef36d6ee5..f766561971c 100644 --- a/tools/server/webui/src/lib/stores/conversations.svelte.ts +++ b/tools/server/webui/src/lib/stores/conversations.svelte.ts @@ -519,6 +519,19 @@ class ConversationsStore { return await DatabaseService.getConversationMessages(convId); } + /** + * Imports conversations from provided data (without file picker) + * @param data - Array of conversation data with messages + * @returns Import result with counts + */ + async importConversationsData( + data: ExportedConversations + ): Promise<{ imported: number; skipped: number }> { + const result = await DatabaseService.importConversations(data); + await this.loadConversations(); + return result; + } + /** * Adds a message to the active messages array * Used by chatStore when creating new messages diff --git a/tools/server/webui/src/lib/stores/settings.svelte.ts b/tools/server/webui/src/lib/stores/settings.svelte.ts index 5140995eea4..2b7d8db1021 100644 --- a/tools/server/webui/src/lib/stores/settings.svelte.ts +++ b/tools/server/webui/src/lib/stores/settings.svelte.ts @@ -370,6 +370,10 @@ class SettingsStore { return { ...this.config }; } + canSyncParameter(key: string): boolean { + return ParameterSyncService.canSyncParameter(key); + } + /** * Get parameter information including source for a specific parameter */ diff --git a/tools/server/webui/src/lib/utils/file-type.ts b/tools/server/webui/src/lib/utils/file-type.ts index f096b463d40..ff7ed6b0c98 100644 --- a/tools/server/webui/src/lib/utils/file-type.ts +++ b/tools/server/webui/src/lib/utils/file-type.ts @@ -77,6 +77,13 @@ export function getFileTypeCategory(mimeType: string): FileTypeCategory | null { case MimeTypeText.SVELTE: case MimeTypeText.LATEX: case MimeTypeText.BIBTEX: + case MimeTypeText.CUDA: + case MimeTypeText.CPP_HDR: + case MimeTypeText.CSHARP: + case MimeTypeText.HASKELL: + case MimeTypeText.PROPERTIES: + case MimeTypeText.TEX: + case MimeTypeText.TEX_APP: return FileTypeCategory.TEXT; default: @@ -144,6 +151,12 @@ export function getFileTypeCategoryByExtension(filename: string): FileTypeCatego case FileExtensionText.SVELTE: case FileExtensionText.TEX: case FileExtensionText.BIB: + case FileExtensionText.COMP: + case FileExtensionText.CU: + case FileExtensionText.CUH: + case FileExtensionText.HPP: + case FileExtensionText.HS: + case FileExtensionText.PROPERTIES: return FileTypeCategory.TEXT; default: diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 8e1cd9a9dae..369502d7aec 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -144,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT) find_library(SECURITY_FRAMEWORK Security REQUIRED) target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) endif() + if (WIN32 AND NOT MSVC) + target_link_libraries(${TARGET} PUBLIC crypt32) + endif() endif()