-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerator-simple.py
38 lines (29 loc) · 1.29 KB
/
generator-simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import fire
import jsonlines
from progressbar import progressbar
from AdvDecoder import decode
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel, RobertaForSequenceClassification
from BatchTextGenerationPipeline import BatchTextGenerationPipeline
from IsFakePipeline import IsFakePipelineHF
def main(file, lines, sequences_per_step=12, sequence_length=64,
         device=0, model_name="EleutherAI/gpt-neo-125M"):
    """Generate unconditional text samples with a causal LM and append them
    to a JSON Lines file, one ``{"text": ...}`` object per sample.

    Args:
        file: Path to the output ``.jsonl`` file (opened in append mode, so
            repeated runs accumulate samples instead of overwriting).
        lines: Target number of samples. Note: integer division by
            ``sequences_per_step`` means any remainder is silently dropped,
            so the actual count is ``(lines // sequences_per_step) *
            sequences_per_step``.
        sequences_per_step: Sequences generated per model call (batch size).
        sequence_length: Number of tokens to generate per sequence.
        device: CUDA device index to move the model to (default 0).
        model_name: Hugging Face model id to load (default GPT-Neo 125M).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    generator = BatchTextGenerationPipeline(model=model, tokenizer=tokenizer,
                                            device=device)
    with jsonlines.open(file, mode='a') as writer:
        for _ in progressbar(range(lines // sequences_per_step)):
            sequences = generator.generate(
                prompt='',  # empty prompt -> unconditional generation
                generate_length=sequence_length,
                num_return_sequences=sequences_per_step,
                do_sample=True,
                top_p=0.99,              # nucleus sampling, near-full mass
                no_repeat_ngram_size=3,  # forbid repeated 3-grams
            )
            for text in sequences:
                writer.write({'text': text})
# Expose main() as a CLI via python-fire, e.g.:
#   python generator-simple.py out.jsonl 1200 --sequences_per_step=12
if __name__ == '__main__':
    fire.Fire(main)