1 parent 1194331 commit 1222399
lib/tokenizers/encoding.ex
@@ -35,6 +35,19 @@ defmodule Tokenizers.Encoding do
   @spec get_type_ids(Encoding.t()) :: [integer()]
   def get_type_ids(encoding), do: encoding |> Native.get_type_ids() |> Shared.unwrap()
 
+  @doc """
+  Get special tokens mask from an encoding.
+  """
+  @spec get_special_tokens_mask(Encoding.t()) :: [integer()]
+  def get_special_tokens_mask(encoding),
+    do: encoding |> Native.get_special_tokens_mask() |> Shared.unwrap()
+
+  @doc """
+  Get offsets from an encoding.
+  """
+  @spec get_offsets(Encoding.t()) :: [{integer(), integer()}]
+  def get_offsets(encoding), do: encoding |> Native.get_offsets() |> Shared.unwrap()
+
   @doc """
   Truncate the encoding to the given length.
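For orientation, a minimal usage sketch of the two new accessors follows. It is not part of the diff: the pretrained model name is only an example, and the expected values assume a BERT-style tokenizer like the one exercised in the test suite further down.

# Sketch only (not in this commit): load a tokenizer and read the new metadata.
# "bert-base-cased" is an example model; any pretrained tokenizer should work.
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")
{:ok, [encoding | _]} = Tokenizers.Tokenizer.encode(tokenizer, ["This is a test"])

# 1 marks special tokens such as [CLS]/[SEP]; 0 marks regular tokens.
Tokenizers.Encoding.get_special_tokens_mask(encoding)
#=> [1, 0, 0, 0, 0, 1]

# {start, stop} spans into the input text; special tokens come back as {0, 0}.
Tokenizers.Encoding.get_offsets(encoding)
#=> [{0, 0}, {0, 4}, {5, 7}, {8, 9}, {10, 14}, {0, 0}]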
lib/tokenizers/native.ex
@@ -20,6 +20,8 @@ defmodule Tokenizers.Native do
   def get_type_ids(_encoding), do: err()
   def get_ids(_encoding), do: err()
   def get_tokens(_encoding), do: err()
+  def get_special_tokens_mask(_encoding), do: err()
+  def get_offsets(_encoding), do: err()
   def get_vocab(_tokenizer, _with_added_tokens), do: err()
   def get_vocab_size(_tokenizer, _with_added_tokens), do: err()
   def id_to_token(_tokenizer, _id), do: err()
native/ex_tokenizers/src/encoding.rs
@@ -45,6 +45,20 @@ pub fn get_type_ids(encoding: ExTokenizersEncoding) -> Result<Vec<u32>, ExTokenizersError> {
     Ok(encoding.resource.0.get_type_ids().to_vec())
 }
 
+#[rustler::nif]
+pub fn get_special_tokens_mask(
+    encoding: ExTokenizersEncoding,
+) -> Result<Vec<u32>, ExTokenizersError> {
+    Ok(encoding.resource.0.get_special_tokens_mask().to_vec())
+}
+
+#[rustler::nif]
+pub fn get_offsets(
+    encoding: ExTokenizersEncoding,
+) -> Result<Vec<(usize, usize)>, ExTokenizersError> {
+    Ok(encoding.resource.0.get_offsets().to_vec())
+}
+
 #[rustler::nif]
 pub fn n_tokens(encoding: ExTokenizersEncoding) -> Result<usize, ExTokenizersError> {
     Ok(encoding.resource.0.len())
native/ex_tokenizers/src/error.rs
@@ -21,7 +21,7 @@ pub enum ExTokenizersError {
     Unknown(#[from] anyhow::Error),
 }
 
-impl<'a> Encoder for ExTokenizersError {
+impl Encoder for ExTokenizersError {
     fn encode<'b>(&self, env: Env<'b>) -> Term<'b> {
         format!("{:?}", self).encode(env)
     }
native/ex_tokenizers/src/lib.rs
@@ -29,6 +29,8 @@ rustler::init!(
         get_attention_mask,
         get_type_ids,
         get_ids,
+        get_special_tokens_mask,
+        get_offsets,
         get_model,
         get_model_details,
         get_tokens,
test/tokenizers/tokenizer_test.exs
@@ -80,4 +80,24 @@ defmodule Tokenizers.TokenizerTest do
     assert decoded == text
   end
 
+
+  describe "encode metadata" do
+    test "can return special tokens mask", %{tokenizer: tokenizer} do
+      text = ["This is a test", "And so is this"]
+      {:ok, encodings} = Tokenizer.encode(tokenizer, text)
+      special_tokens_mask = Enum.map(encodings, &Encoding.get_special_tokens_mask/1)
+      assert [[1, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 1]] == special_tokens_mask
+    end
+
+    test "can return offsets", %{tokenizer: tokenizer} do
+      text = ["This is a test", "And so is this"]
+      {:ok, encodings} = Tokenizer.encode(tokenizer, text)
+      offsets = Enum.map(encodings, &Encoding.get_offsets/1)
+
+      assert [
+               [{0, 0}, {0, 4}, {5, 7}, {8, 9}, {10, 14}, {0, 0}],
+               [{0, 0}, {0, 3}, {4, 6}, {7, 9}, {10, 14}, {0, 0}]
+             ] == offsets
+    end
+  end
 end
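As a rough illustration (a sketch, not part of the commit) of how the asserted offsets map back to the input: each {start, stop} pair indexes into the original string, and for ASCII input like the test sentences byte and character positions coincide, so the spans can be recovered with binary_part/3.

# Hypothetical check, not in this commit: recover the tokens of
# "This is a test" from the offsets asserted above (the {0, 0}
# entries for special tokens are skipped).
text = "This is a test"

Enum.map([{0, 4}, {5, 7}, {8, 9}, {10, 14}], fn {start, stop} ->
  binary_part(text, start, stop - start)
end)
#=> ["This", "is", "a", "test"]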