|
53 | 53 | " 1. [Intervene on Recurrent NNs](#Recurrent-NNs-(Intervene-a-Specific-Timestep))\n", |
54 | 54 | " 1. [Intervene across Times with RNNs](#Recurrent-NNs-(Intervene-cross-Time))\n", |
55 | 55 | " 1. [Intervene on LM Generation](#LMs-Generation)\n", |
| 56 | + " 1. [Advanced Intervention on LM Generation (Model Steering)](#Advanced-Intervention-on-LMs-Generation-(Model-Steering))\n", |
56 | 57 | " 1. [Debiasing with Backpack LMs](#Debiasing-with-Backpack-LMs)\n", |
57 | 58 | " 1. [Saving and Loading](#Saving-and-Loading)\n", |
58 | 59 | " 1. [Multi-Source Intervention (Parallel)](#Multi-Source-Interchange-Intervention-(Parallel-Mode))\n", |
|
1185 | 1186 | "))" |
1186 | 1187 | ] |
1187 | 1188 | }, |
| 1189 | + { |
| 1190 | + "cell_type": "markdown", |
| 1191 | + "id": "0b89244e-fbc7-4515-b22c-83fae00224cb", |
| 1192 | + "metadata": {}, |
| 1193 | + "source": [ |
| 1194 | + "### Advanced Intervention on LMs Generation (Model Steering)\n", |
| 1195 | + "\n", |
| 1196 | + "We also support model steering with interventions during model generation. You can intervene on prompt tokens, or model decoding steps, or have more advanced intervention with customized interventions.\n", |
| 1197 | + "\n", |
| 1198 | + "Note that you must set `keep_last_dim = True` to get token-level representations!" |
| 1199 | + ] |
| 1200 | + }, |
| 1201 | + { |
| 1202 | + "cell_type": "code", |
| 1203 | + "execution_count": 24, |
| 1204 | + "id": "43422e38-d930-4354-9dc5-191e2abcf928", |
| 1205 | + "metadata": {}, |
| 1206 | + "outputs": [ |
| 1207 | + { |
| 1208 | + "data": { |
| 1209 | + "application/vnd.jupyter.widget-view+json": { |
| 1210 | + "model_id": "c046df6ad83d4f6381730fc940f7b866", |
| 1211 | + "version_major": 2, |
| 1212 | + "version_minor": 0 |
| 1213 | + }, |
| 1214 | + "text/plain": [ |
| 1215 | + "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]" |
| 1216 | + ] |
| 1217 | + }, |
| 1218 | + "metadata": {}, |
| 1219 | + "output_type": "display_data" |
| 1220 | + }, |
| 1221 | + { |
| 1222 | + "data": { |
| 1223 | + "application/vnd.jupyter.widget-view+json": { |
| 1224 | + "model_id": "7ec60913371647fc85e602b189a5c50f", |
| 1225 | + "version_major": 2, |
| 1226 | + "version_minor": 0 |
| 1227 | + }, |
| 1228 | + "text/plain": [ |
| 1229 | + "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" |
| 1230 | + ] |
| 1231 | + }, |
| 1232 | + "metadata": {}, |
| 1233 | + "output_type": "display_data" |
| 1234 | + }, |
| 1235 | + { |
| 1236 | + "name": "stdout", |
| 1237 | + "output_type": "stream", |
| 1238 | + "text": [ |
| 1239 | + "Extracting happy vector ...\n" |
| 1240 | + ] |
| 1241 | + } |
| 1242 | + ], |
| 1243 | + "source": [ |
| 1244 | + "import pyvene as pv\n", |
| 1245 | + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", |
| 1246 | + "\n", |
| 1247 | + "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2-2b-it\")\n", |
| 1248 | + "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b-it\")\n", |
| 1249 | + "\n", |
| 1250 | + "print(\"Extracting happy vector ...\")\n", |
| 1251 | + "happy_id = tokenizer(\"happy\")['input_ids'][-1]\n", |
| 1252 | + "happy_vector = model.model.embed_tokens.weight[happy_id].to(\"cuda\")\n", |
| 1253 | + "\n", |
| 1254 | + "# Create a \"happy\" addition intervention\n", |
| 1255 | + "class HappyIntervention(pv.ConstantSourceIntervention):\n", |
| 1256 | + " def __init__(self, **kwargs):\n", |
| 1257 | + " super().__init__(\n", |
| 1258 | + " **kwargs, \n", |
| 1259 | + " keep_last_dim=True) # you must set keep_last_dim=True to get tokenized reprs.\n", |
| 1260 | + " self.called_counter = 0\n", |
| 1261 | + "\n", |
| 1262 | + " def forward(self, base, source=None, subspaces=None):\n", |
| 1263 | + " if subspaces[\"logging\"]:\n", |
| 1264 | + " print(f\"(called {self.called_counter} times) incoming reprs shape:\", base.shape)\n", |
| 1265 | + " self.called_counter += 1\n", |
| 1266 | + " return base + subspaces[\"mag\"] * happy_vector\n", |
| 1267 | + "\n", |
| 1268 | + "# Mount the intervention to our steering model\n", |
| 1269 | + "pv_config = pv.IntervenableConfig(representations=[{\n", |
| 1270 | + " \"layer\": 20,\n", |
| 1271 | + " \"component\": f\"model.layers[20].output\",\n", |
| 1272 | + " \"low_rank_dimension\": 1,\n", |
| 1273 | + " \"intervention\": HappyIntervention(\n", |
| 1274 | + " embed_dim=model.config.hidden_size, \n", |
| 1275 | + " low_rank_dimension=1)}])\n", |
| 1276 | + "pv_model = pv.IntervenableModel(pv_config, model)\n", |
| 1277 | + "pv_model.set_device(\"cuda\")" |
| 1278 | + ] |
| 1279 | + }, |
| 1280 | + { |
| 1281 | + "cell_type": "code", |
| 1282 | + "execution_count": 18, |
| 1283 | + "id": "dc70ebae-793a-4b2b-a3e3-a2118cc66e1e", |
| 1284 | + "metadata": {}, |
| 1285 | + "outputs": [ |
| 1286 | + { |
| 1287 | + "name": "stderr", |
| 1288 | + "output_type": "stream", |
| 1289 | + "text": [ |
| 1290 | + "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n" |
| 1291 | + ] |
| 1292 | + }, |
| 1293 | + { |
| 1294 | + "name": "stdout", |
| 1295 | + "output_type": "stream", |
| 1296 | + "text": [ |
| 1297 | + "(called 0 times) incoming reprs shape: torch.Size([1, 17, 2304])\n", |
| 1298 | + "(called 1 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1299 | + "(called 2 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1300 | + "(called 3 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1301 | + "(called 4 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1302 | + "(called 5 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1303 | + "(called 6 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1304 | + "(called 7 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1305 | + "(called 8 times) incoming reprs shape: torch.Size([1, 1, 2304])\n", |
| 1306 | + "(called 9 times) incoming reprs shape: torch.Size([1, 1, 2304])\n" |
| 1307 | + ] |
| 1308 | + } |
| 1309 | + ], |
| 1310 | + "source": [ |
| 1311 | + "prompt = \"Write a story for me about dragon.\"\n", |
| 1312 | + "\n", |
| 1313 | + "prompt = tokenizer.decode(tokenizer.apply_chat_template(\n", |
| 1314 | + " [{\"role\": \"user\", \"content\": prompt}], \n", |
| 1315 | + " tokenize=True, add_generation_prompt=True)[1:])\n", |
| 1316 | + "\n", |
| 1317 | + "inputs = tokenizer(\n", |
| 1318 | + " prompt, return_tensors=\"pt\", padding=True, truncation=True\n", |
| 1319 | + ").to(\"cuda\")\n", |
| 1320 | + "_, generations = pv_model.generate(\n", |
| 1321 | + " inputs, \n", |
| 1322 | + " unit_locations=None, # set to None means intervention will be applied for each forward call\n", |
| 1323 | + " intervene_on_prompt=True, # intervention will be called for the prompt kv cache call\n", |
| 1324 | + " subspaces=[{\"mag\": 2.0, \"logging\": True}], # other metadata\n", |
| 1325 | + " max_new_tokens=10, do_sample=True, temperature=1.0)" |
| 1326 | + ] |
| 1327 | + }, |
| 1328 | + { |
| 1329 | + "cell_type": "markdown", |
| 1330 | + "id": "afa622a1-5d93-43e6-afc2-436bd166299d", |
| 1331 | + "metadata": {}, |
| 1332 | + "source": [ |
| 1333 | + "**Here you go!** You can see that the intervention is called 10 times. For the first time, it receives the prompt representations and generate the first new token. \n", |
| 1334 | + "\n", |
| 1335 | + "Now, let's generate some steering results." |
| 1336 | + ] |
| 1337 | + }, |
| 1338 | + { |
| 1339 | + "cell_type": "code", |
| 1340 | + "execution_count": 30, |
| 1341 | + "id": "4882c229-8529-4a16-bec0-773e77217d27", |
| 1342 | + "metadata": {}, |
| 1343 | + "outputs": [ |
| 1344 | + { |
| 1345 | + "data": { |
| 1346 | + "text/plain": [ |
| 1347 | + "\"Anya's grandmother, Elara, had always said happy souls were the happiest kind – even happy dragons. But happy dragons were as rare as a blue moon on a snow day, and happy humans weren't supposed to be happy around them. Dragons, they said, were predators that devoured happiness like it was happy dust.\\n\\nAnya knew better. Happy was a whisper, a happy breeze on the sun-dappled meadow. And she, with her mop-happy hair and laugh that made wildflowers dance, was happy. So she snuck off to the Forbidden Forest, a place where only brave hearts dared to roam\"" |
| 1348 | + ] |
| 1349 | + }, |
| 1350 | + "execution_count": 30, |
| 1351 | + "metadata": {}, |
| 1352 | + "output_type": "execute_result" |
| 1353 | + } |
| 1354 | + ], |
| 1355 | + "source": [ |
| 1356 | + "_, generations = pv_model.generate(\n", |
| 1357 | + " inputs, \n", |
| 1358 | + " unit_locations=None, # set to None means intervention will be applied for each forward call\n", |
| 1359 | + " intervene_on_prompt=True, # intervention will be called for the prompt kv cache call\n", |
| 1360 | + " subspaces=[{\"mag\": 70.0, \"logging\": False}], # other metadata\n", |
| 1361 | + " max_new_tokens=128, do_sample=True, temperature=1.0)\n", |
| 1362 | + "\n", |
| 1363 | + "tokenizer.decode(generations[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)" |
| 1364 | + ] |
| 1365 | + }, |
| 1366 | + { |
| 1367 | + "cell_type": "markdown", |
| 1368 | + "id": "3cc17327-ea2d-449f-9f11-e94435b1e734", |
| 1369 | + "metadata": {}, |
| 1370 | + "source": [ |
| 1371 | + "Great! This is your super-happy model. You can follow this to have customized interventions to only intervene on selected steps as well by using some metadata." |
| 1372 | + ] |
| 1373 | + }, |
1188 | 1374 | { |
1189 | 1375 | "cell_type": "markdown", |
1190 | 1376 | "id": "26d25dc6", |
|