From 5a37ecd4897a2ee78683c6349509f1e5f108f44d Mon Sep 17 00:00:00 2001 From: preciz Date: Sat, 12 Oct 2024 23:38:02 +0200 Subject: [PATCH 1/2] Fix: call apply/3 as intended --- lib/axon/quantization.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex index b48d18ffa..ed976b8d0 100644 --- a/lib/axon/quantization.ex +++ b/lib/axon/quantization.ex @@ -132,7 +132,7 @@ defmodule Axon.Quantization do fun = case opts[:kernel_initializer] do init when is_atom(init) -> - apply(Axon.Initializers, []) + apply(Axon.Initializers, init, []) fun when is_function(fun) -> fun From 3dd4fb90bc924a52c38dc8bc8d46f5266e9000a1 Mon Sep 17 00:00:00 2001 From: preciz Date: Sun, 13 Oct 2024 01:42:05 +0200 Subject: [PATCH 2/2] Add tests for Axon.Quantizaiton.weight_only_quantized_dense --- test/axon/quantization_test.exs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/axon/quantization_test.exs b/test/axon/quantization_test.exs index 4a289ce03..3d728158c 100644 --- a/test/axon/quantization_test.exs +++ b/test/axon/quantization_test.exs @@ -42,4 +42,18 @@ defmodule Axon.QuantizationTest do assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp)) end end + + describe "weight_only_quantized_dense" do + test "inits and executes properly" do + model = + Axon.input("input") + |> Axon.Quantization.weight_only_quantized_dense(10) + + assert {init_fn, _} = Axon.build(model) + assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty()) + + assert {_, predict_fn} = Axon.build(model) + assert predict_fn.(model_state, Nx.broadcast(1.0, {1, 1})) + end + end end