Skip to content

Commit 81338b6

Browse files
committed
Add evaluation scripts
1 parent 6e59b7a commit 81338b6

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

abx.py

Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,79 @@
1+
import argparse
2+
import ast
3+
from pathlib import Path
4+
5+
import ABXpy.analyze as analyze
6+
import ABXpy.distances.distances as distances
7+
import ABXpy.distances.metrics.cosine as cosine
8+
import ABXpy.distances.metrics.dtw as dtw
9+
import ABXpy.score as score
10+
import pandas
11+
from ABXpy.misc.any2h5features import convert
12+
13+
14+
def dtw_cosine_distance(x, y, normalized):
    """Return the DTW distance between ``x`` and ``y`` using cosine distance.

    Thin adapter that forwards to ABXpy's ``dtw.dtw`` with
    ``cosine.cosine_distance`` as the metric argument. ``normalized`` is
    passed through unchanged as the fourth positional argument.
    """
    frame_metric = cosine.cosine_distance
    return dtw.dtw(x, y, frame_metric, normalized)
16+
17+
18+
def average_abx(filename, task_type):
    """Return the average ABX error rate, in percent, from an analysis file.

    The tab-separated file at ``filename`` is loaded with pandas and its
    ``score`` column is averaged in three stages:

    1. within each (speaker(s), phone_1, phone_2) group, which collapses the
       remaining rows (the original code calls this "aggregate on context");
    2. within each (phone_1, phone_2) pair, which collapses speakers;
    3. over all phone pairs.

    Args:
        filename: Path of the tab-separated analysis file. It must contain
            ``phone_1``, ``phone_2`` and ``score`` columns, plus
            ``speaker_1``/``speaker_2`` for the "across" task or a ``by``
            column for the "within" task.
        task_type: Either ``"across"`` or ``"within"``.

    Returns:
        ``(1 - mean_score) * 100`` where ``mean_score`` is the stage-3 mean.

    Raises:
        ValueError: If ``task_type`` is neither ``"across"`` nor ``"within"``.
    """
    df = pandas.read_csv(filename, sep="\t")

    if task_type == "across":
        speaker_keys = ["speaker_1", "speaker_2"]
    elif task_type == "within":
        # Each "by" cell is the textual form of a 3-tuple; its first element
        # is used as the speaker label. literal_eval only parses Python
        # literals, so no code from the file is executed.
        parsed = [ast.literal_eval(cell) for cell in df["by"]]
        df["speaker"] = [speaker for speaker, _, _ in parsed]
        speaker_keys = ["speaker"]
    else:
        raise ValueError("Unknown task type: {0}".format(task_type))

    phone_keys = ["phone_1", "phone_2"]

    # Stage 1: mean over everything except speaker(s) and phone pair.
    per_speaker = df.groupby(speaker_keys + phone_keys,
                             as_index=False)["score"].mean()
    # Stage 2: mean over speakers for each phone pair.
    per_phone = per_speaker.groupby(phone_keys, as_index=False)["score"].mean()

    # Stage 3: average the score column explicitly. Calling DataFrame.mean()
    # on the whole frame would also try to average the string-valued phone
    # columns, which raises on recent pandas versions.
    average = per_phone["score"].mean()
    return (1.0 - average) * 100
42+
43+
44+
def evaluate_abx(args):
    """Run the ABX pipeline stages whose outputs are missing, then print the score.

    Four intermediate files are kept in ``args.out_dir`` (created if needed):
    ``features.features``, ``data.distance``, ``data.score`` and ``data.csv``.
    Each stage is executed only when its output file does not exist yet, so a
    re-run reuses whatever is already on disk. Finally the analysis file is
    summarised with ``average_abx`` and the result is printed.

    Args:
        args: Namespace providing ``out_dir``, ``feature_dir``, ``task_path``
            and ``task_type``.
    """
    workdir = Path(args.out_dir)
    workdir.mkdir(parents=True, exist_ok=True)

    features_file = workdir / "features.features"
    distances_file = workdir / "data.distance"
    scores_file = workdir / "data.score"
    analysis_file = workdir / "data.csv"
    task_file = str(args.task_path)

    # (output file, action producing it), in dependency order.
    stages = (
        (features_file,
         lambda: convert(args.feature_dir, h5_filename=str(features_file))),
        (distances_file,
         lambda: distances.compute_distances(
             str(features_file), "features", task_file,
             str(distances_file), dtw_cosine_distance,
             normalized=True, n_cpu=6)),
        (scores_file,
         lambda: score.score(task_file, str(distances_file), str(scores_file))),
        (analysis_file,
         lambda: analyze.analyze(task_file, str(scores_file), str(analysis_file))),
    )
    for produced, run_stage in stages:
        if not produced.exists():
            run_stage()

    result = average_abx(str(analysis_file), args.task_type)
    print("average abx: {:.3f}".format(result))
70+
71+
72+
if __name__ == "__main__":
    # Command-line entry point for the ABX evaluation.
    #
    # --task-type and --out-dir are consumed on every run, so they are
    # validated up front: otherwise a missing or misspelled task type would
    # only be reported after the distance/score/analyze stages have finished.
    # --task-path and --feature-dir stay optional because a run that finds
    # its intermediate files already present in --out-dir does not need them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task-type", type=str, required=True,
                        choices=["across", "within"],
                        help="ABX task type")
    parser.add_argument("--task-path", type=str,
                        help="Path to the ABX task file")
    parser.add_argument("--feature-dir", type=str,
                        help="Directory of features to convert")
    parser.add_argument("--out-dir", type=str, required=True,
                        help="Directory for intermediate and result files")
    args = parser.parse_args()
    evaluate_abx(args)

encode.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,51 @@
1+
import argparse
2+
from pathlib import Path
3+
import json
4+
import numpy as np
5+
import torch
6+
from model import Encoder
7+
from tqdm import tqdm
8+
9+
10+
def encode_dataset(args, params):
    """Encode every ``*.mel.npy`` file under ``args.in_dir`` and save the result.

    An ``Encoder`` is built from ``params``, its weights are restored from the
    ``"model"`` entry of the checkpoint at ``args.checkpoint``, and it is put
    in eval mode. Each input file found recursively under ``args.in_dir`` is
    run through the model without gradients; the first model output is written
    as ``features`` to an ``.npz`` file at the mirrored relative path under
    ``args.out_dir``, together with a ``time`` array.

    Args:
        args: Namespace providing ``checkpoint``, ``in_dir`` and ``out_dir``.
        params: Configuration mapping; reads ``preprocessing.num_mels``,
            ``preprocessing.hop_length``, ``preprocessing.sample_rate`` and
            ``model.encoder_channels`` / ``z_dim`` / ``c_dim``.
    """
    prep_cfg = params["preprocessing"]
    model_cfg = params["model"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Encoder(in_channels=prep_cfg["num_mels"],
                    encoder_channels=model_cfg["encoder_channels"],
                    z_dim=model_cfg["z_dim"],
                    c_dim=model_cfg["c_dim"])
    model.to(device)

    print("Load checkpoint from: {}:".format(args.checkpoint))
    # Keep tensors on the storage they were deserialised to (CPU); the
    # state dict is then copied into the model on `device`.
    state = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(state["model"])
    model.eval()

    out_root = Path(args.out_dir)
    out_root.mkdir(exist_ok=True, parents=True)

    # Duration of one hop, in seconds.
    hop_length_seconds = prep_cfg["hop_length"] / prep_cfg["sample_rate"]

    in_root = Path(args.in_dir)
    for mel_path in tqdm(in_root.rglob("*.mel.npy")):
        mel = torch.from_numpy(np.load(mel_path)).unsqueeze(0).to(device)
        with torch.no_grad():
            z, _, _, _ = model(mel)

        # NOTE(review): squeeze() drops every size-1 dimension, not only the
        # batch dimension added above - confirm this is intended for very
        # short inputs.
        features = z.squeeze().cpu().numpy()

        # One timestamp per output row, evenly spaced from 0 to the start of
        # the last mel frame (assumes the last mel axis is time - TODO confirm).
        last_frame_start = (mel.size(-1) - 1) * hop_length_seconds
        timestamps = np.linspace(0, last_frame_start, len(features))

        # "x/y.mel.npy" -> "<out_dir>/x/y.npz"
        stem = mel_path.relative_to(in_root).with_suffix("")
        target = out_root / stem
        target.parent.mkdir(exist_ok=True, parents=True)
        np.savez(target.with_suffix(".npz"), features=features, time=timestamps)
41+
42+
43+
if __name__ == "__main__":
    # Command-line entry point: parse the options, read the model and
    # preprocessing settings from config.json in the current working
    # directory, then encode the dataset.
    cli = argparse.ArgumentParser()
    for flag, description in (
            ("--checkpoint", "Checkpoint path to resume"),
            ("--in-dir", "Directory to encode"),
            ("--out-dir", "Output path"),
    ):
        cli.add_argument(flag, type=str, help=description)
    args = cli.parse_args()

    with open("config.json") as config_file:
        params = json.load(config_file)

    encode_dataset(args, params)

0 commit comments

Comments
 (0)