-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathshift.py
28 lines (22 loc) · 822 Bytes
/
shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import List, Text, Set, Any
import torch
from modifyinput import get_is_shifted
def shift_example(original: List[int], do_not_shift: Set[int], shift: int) -> List[int]:
    """Return a copy of *original* with each value offset by *shift*, except protected ones.

    Args:
        original: Integer values to shift; this list is not modified.
        do_not_shift: Values that are copied through unchanged.
        shift: Offset added to every value not in *do_not_shift*.

    Returns:
        A new list with the same length and order as *original*.
    """
    return [x if x in do_not_shift else x + shift for x in original]
def add_shifted_input(original: List[List[int]], do_not_shift: Set[int], shift: int) -> None:
    """Append a shifted copy of every example to *original*, in place.

    After the call, *original* holds its ``n`` input examples followed by
    ``n`` shifted copies in the same order, so its length doubles. Each copy
    is produced by ``shift_example(example, do_not_shift, shift)``.

    Args:
        original: List of examples; mutated by appending the shifted copies.
        do_not_shift: Values left unchanged when shifting.
        shift: Offset added to every other value.
    """
    # Build the copies as a list before extending. Passing a lazy generator
    # over ``original`` to ``extend`` would keep consuming the items being
    # appended and never terminate.
    shifted = [shift_example(example, do_not_shift, shift) for example in original]
    original.extend(shifted)
def remove_parallel_data(original: List[List[int]]) -> List[List[int]]:
    """Return the first and last quarters of *original*, dropping the middle half.

    *original* is not modified; a new list is returned.

    NOTE(review): presumably the input is laid out as four equal-sized
    segments (e.g. produced by repeated ``add_shifted_input`` calls) — confirm
    the intended segment meaning against callers.

    Args:
        original: Examples whose length must be a multiple of 4.

    Returns:
        ``original[:q] + original[-q:]`` where ``q = len(original) // 4``.

    Raises:
        ValueError: If ``len(original)`` is not a multiple of 4.
    """
    n = len(original)
    if n % 4 != 0:
        raise ValueError("Data not parallel at all?")
    quarter = n // 4
    # Slice from an explicit start index instead of ``[-quarter:]``. With an
    # empty input, ``quarter`` is 0 and ``[-0:]`` would select the whole list.
    return original[:quarter] + original[n - quarter:]