Skip to content

Commit

Permalink
unify test_engine and test_parallel_sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 14, 2023
1 parent 8012daa commit 8fb18c2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 163 deletions.
53 changes: 29 additions & 24 deletions serve/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@


def _test(args: argparse.Namespace):
# Examples. "--max-output-len" can be used to specify the number of output tokens.
#
# Profile the gpu memory usage, and use the maximum number of cache blocks possible:
# python serve/tests/test_engine_paged_cache_model.py --local-id vicuna-v1-7b-q4f16_ft --max-num-batched-tokens 2560 --max-input-len 256
#
# Mistral:
# python serve/tests/test_engine_paged_cache_model.py --local-id Mistral-7B-v0.1-q0f16 --long-prompt --max-num-batched-tokens 24000 --max-input-len 8000 --max-output-len 20
#
# Disco:
# python serve/tests/test_engine_paged_cache_model.py --local-id vicuna-v1-7b-q0f16-presharded-gpu2

engine_config = get_engine_config(
{
"use_staging_engine": args.use_staging_engine,
Expand Down Expand Up @@ -63,12 +52,17 @@ def _test(args: argparse.Namespace):
sampling_params_greedy = SamplingParams(
temperature=0.0,
)
sampling_params_random = SamplingParams(
temperature=1.0,
top_p=1.0,
)

if args.use_random_sampling:
sampling_params_random = SamplingParams(
temperature=1.0,
top_p=1.0,
)
num_sequences = args.num_sequences_to_sample

if num_sequences > 1:
sampling_params_choices = [sampling_params_random]
elif args.use_random_sampling:
# This tests different sampling types in the same batch
sampling_params_choices = [sampling_params_random, sampling_params_greedy]
else:
sampling_params_choices = [sampling_params_greedy]
Expand All @@ -79,9 +73,9 @@ def _test(args: argparse.Namespace):
else:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"The president of the United States is a powerful man. But he can also be",
"The future of AI is full of promise. But we need to carefully",
]

for i, prompt in enumerate(prompts):
Expand All @@ -95,25 +89,32 @@ def _test(args: argparse.Namespace):
max_tokens=args.max_output_len, stop_sequences=None
),
debug_options=DebugOptions(prompt=prompt),
num_sequences=num_sequences,
)
]
)

generated = ["" for _ in range(len(prompts))]
generated = [["" for _ in range(num_sequences)] for _ in range(len(prompts))]

while engine.has_pending_requests():
results = engine.step()
for res in results.outputs:
seq = res.sequences[0]
if not seq.is_finished:
generated[int(res.request_id)] += seq.delta
for i, seq in enumerate(res.sequences):
if not seq.is_finished:
generated[int(res.request_id)][i] += seq.delta

if args.long_prompt:
for g in generated:
print(f"Generated text = '{g}'")
for i, seq in enumerate(g):
print(f"Generated {i}-th sample = '{seq}'")
print("")
print("")
else:
for p, g in zip(prompts, generated):
print(f"Prompt = '{p}', generated text = '{g}'")
print(f"Prompt = '{p}'")
for i, seq in enumerate(g):
print(f"Generated {i}-th sample = '{seq}'")
print("")

if args.use_staging_engine:
engine.stop()
Expand All @@ -128,6 +129,7 @@ def _test(args: argparse.Namespace):
parser.add_argument("--max-output-len", type=int, default=20)
parser.add_argument("--long-prompt", action="store_true")
parser.add_argument("--use-random-sampling", action="store_true")
parser.add_argument("--num-sequences-to-sample", type=int, default=1)
parser.add_argument("--use-staging-engine", action="store_true")
parser.add_argument("--min-decode-steps", type=int, default=12)
parser.add_argument("--max-decode-steps", type=int, default=16)
Expand All @@ -143,6 +145,9 @@ def _test(args: argparse.Namespace):
args.max_input_len = 10000
args.max_num_sequences = 5

if args.num_sequences_to_sample > 1:
args.use_random_sampling = True

torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

Expand Down
139 changes: 0 additions & 139 deletions serve/tests/test_parallel_sampling.py

This file was deleted.

0 comments on commit 8fb18c2

Please sign in to comment.