Skip to content

Commit

Permalink
Fix import paths in pretraining scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Luodian committed Nov 6, 2023
1 parent b1557c7 commit ca69589
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pipeline/demos/demo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def generate(self, image, prompt, no_image_flag=False):
formated_prompt = f"User: {prompt} Assistant:"
model_inputs = self.processor(text=formated_prompt, images=[raw_image_data] if no_image_flag is False else None, device=self.device)
for k, v in model_inputs.items():
model_inputs[k] = v.to(self.device)
model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v]

generation_output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens)
generation_text = self.processor.batch_decode(generation_output, skip_special_tokens=True)
Expand Down
2 changes: 1 addition & 1 deletion pipeline/train/pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from otter_ai import FlamingoForConditionalGeneration, OtterForConditionalGeneration

sys.path.append("../..")
from pipeline.train.data import get_data
from pipeline.mimicit_utils.data import get_data
from pipeline.train.distributed import world_info_from_env
from pipeline.train.train_utils import AverageMeter, get_checkpoint

Expand Down
2 changes: 1 addition & 1 deletion pipeline/train/pretraining_cc3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from otter_ai import FlamingoForConditionalGeneration, OtterForConditionalGeneration

sys.path.append("../..")
from pipeline.train.data import get_data
from pipeline.mimicit_utils.data import get_data
from pipeline.train.distributed import world_info_from_env
from pipeline.train.train_utils import AverageMeter, get_checkpoint

Expand Down

0 comments on commit ca69589

Please sign in to comment.