diff --git a/lib/safetensors.ex b/lib/safetensors.ex index 3730797..8ed2160 100644 --- a/lib/safetensors.ex +++ b/lib/safetensors.ex @@ -69,16 +69,27 @@ defmodule Safetensors do :ok end + cond do + Code.ensure_loaded?(JSON) -> + @json_module JSON + + Code.ensure_loaded?(Jason) -> + @json_module Jason + + true -> + raise "You need to include jason package in your dependencies to make safetensors work with your current Elixir (#{System.version()}) or upgrade to Elixir 1.18+" + end + defp tensor_header_entry(tensor_name, tensor, offset) do end_offset = offset + tensor_byte_size(tensor) header_entry = { tensor_name, - Jason.OrderedObject.new( + %{ dtype: tensor |> Nx.type() |> type_to_dtype(), shape: tensor |> Nx.shape() |> Tuple.to_list(), data_offsets: [offset, end_offset] - ) + } } {header_entry, end_offset} @@ -87,8 +98,8 @@ defmodule Safetensors do defp header_binary(header_entries) do header_json = header_entries - |> Jason.OrderedObject.new() - |> Jason.encode!() + |> Map.new() + |> @json_module.encode!() [<>, header_json] end @@ -212,7 +223,7 @@ defmodule Safetensors do defp decode_header!(header_json) do {_metadata, header} = header_json - |> Jason.decode!() + |> @json_module.decode!() |> Map.pop(@header_metadata_key) header diff --git a/mix.exs b/mix.exs index 5fd32dc..e2676b6 100644 --- a/mix.exs +++ b/mix.exs @@ -24,8 +24,12 @@ defmodule Safetensors.MixProject do defp deps do [ - {:jason, "~> 1.4"}, {:nx, "~> 0.5"}, + + # Optional + {:jason, "~> 1.4", optional: true}, + + # Dev {:ex_doc, "~> 0.37", only: :dev, runtime: false} ] end