@@ -49,6 +49,45 @@ Séance 3
4949
5050* numpy, broadcasting
5151* implémentation d'un chi-deux sans boucle
52+ * comment implémenter la fonction `repeat_interleave
53+ <https://docs.pytorch.org/docs/stable/generated/torch.repeat_interleave.html> `_
54+ avec :epkg: `numpy ` et sans boucle ?
55+ En particulier cet exemple ``torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) ``
56+
57+ Un problème... que fait la fonction suivante ?
58+
59+ .. code-block :: python
60+
61+ def reshape_keep0 (arr , new_shape ):
62+ orig_shape = arr.shape
63+ final_shape = []
64+
65+ for i, dim in enumerate (new_shape):
66+ if dim == 0 :
67+ final_shape.append(orig_shape[i]) # garder dimension originale
68+ else :
69+ final_shape.append(dim)
70+ return arr.reshape(tuple (final_shape))
71+
72+ Comment construire une fonction qui retourne l'argument ``new_shape ``
73+ quand on connaît les dimensions de départ et d'arrivée ?
74+ La fonction doit valider les exemples suivants,
75+ chaque dimension sous forme de chaîne de caractères peut prendre n'importe
76+ quelle valeur.
77+
78+ .. code-block :: python
79+
80+ self .assertEqual((0 , 1024 , - 1 ), align((" d1" , 4 , 256 , " d2" ), (" d1" , 1024 , " d2" )))
81+ self .assertEqual((0 , 0 , 1024 ), align((" d1" , " d2" , 4 , 256 ), (" d1" , " d2" , 1024 )))
82+ self .assertEqual((6 , - 1 ), align((2 , 3 , " d1" ), (" a" , " d1" )))
83+ self .assertEqual((6 , - 1 ), align((2 , 3 , " d1" ), (6 , " d1" )))
84+ self .assertEqual((- 1 , 12 , 196 , 64 ), align((" d1" , 196 , 64 ), (" d2" , 12 , 196 , 64 )))
85+ self .assertEqual((- 1 , 196 , 64 ), align((" d1" , 196 , 64 ), (" d2" , 196 , 64 )))
86+ self .assertEqual((32 , 196 , 64 ), align((32 , 196 , 64 ), (32 , 196 , 64 )))
87+ self .assertEqual((4 , 8 , 196 , 64 ), align((32 , 196 , 64 ), (4 , 8 , 196 , 64 )))
88+ self .assertEqual((32 , 196 , 64 ), align((4 , 8 , 196 , 64 ), (32 , 196 , 64 )))
89+ self .assertEqual((0 , 196 , 64 ), align((" d1" , 196 , 64 ), (" d1" , 196 , 64 )))
90+ self .assertEqual((0 , 196 , 2 , 32 ), align((" d1" , 196 , 64 ), (" d1" , 196 , 2 , 32 )))
5291
5392 Séance 4
5493++++++++
0 commit comments