Installation problem #105
Update: if I change INVALID_POINT3D = np.uint64(-1) to INVALID_POINT3D = np.uint64(2**64 - 1), the code now works up to this point:
(zipnerf) caio@caio:~/2_company/OmniverseGaussianSplatIsaacSimProject/splats_folders/zipnerf-pytorch$ python -m train
2024-06-27 12:40:19: Config(dataset_loader='llff', batching='all_images', batch_size=8192, patch_size=1, factor=4, multiscale=False, multiscale_levels=4, forward_facing=False, render_path=False, llffhold=8, llff_use_all_images_for_training=False, llff_use_all_images_for_testing=False, use_tiffs=False, compute_disp_metrics=False, compute_normal_metrics=False, disable_multiscale_loss=False, randomized=True, near=2.0, far=6.0, exp_name='test', data_dir='/home/caio/2_company/OmniverseGaussianSplatIsaacSimProject/datasets/bicycle', vocab_tree_path=None, render_chunk_size=65536, num_showcase_images=5, deterministic_showcase=True, vis_num_rays=16, vis_decimate=0, dpcpp_backend=False, importance_sampling=False, max_steps=25000, early_exit_steps=None, checkpoint_every=5000, resume_from_checkpoint=True, checkpoints_total_limit=1, gradient_scaling=False, print_every=100, train_render_every=500, data_loss_type='charb', charb_padding=0.001, data_loss_mult=1.0, data_coarse_loss_mult=0.0, interlevel_loss_mult=0.0, anti_interlevel_loss_mult=0.01, orientation_loss_mult=0.0, orientation_coarse_loss_mult=0.0, orientation_loss_target='normals_pred', predicted_normal_loss_mult=0.0, predicted_normal_coarse_loss_mult=0.0, hash_decay_mults=0.1, lr_init=0.01, lr_final=0.001, lr_delay_steps=5000, lr_delay_mult=1e-08, adam_beta1=0.9, adam_beta2=0.99, adam_eps=1e-15, grad_max_norm=0.0, grad_max_val=0.0, distortion_loss_mult=0.005, opacity_loss_mult=0.0, eval_only_once=True, eval_save_output=True, eval_save_ray_data=False, eval_render_interval=1, eval_dataset_limit=2147483647, eval_quantize_metrics=True, eval_crop_borders=0, render_video_fps=60, render_video_crf=18, render_path_frames=120, z_variation=0.0, z_phase=0.0, render_dist_percentile=0.5, render_dist_curve_fn=<ufunc 'log'>, render_path_file=None, render_resolution=None, render_focal=None, render_camtype=None, render_spherical=False, render_save_async=True, render_spline_keyframes=None, render_spline_n_interp=30, render_spline_degree=5, 
render_spline_smoothness=0.03, render_spline_interpolate_exposure=False, rawnerf_mode=False, exposure_percentile=97.0, num_border_pixels_to_mask=0, apply_bayer_mask=False, autoexpose_renders=False, eval_raw_affine_cc=False, zero_glo=False, valid_weight_thresh=0.05, isosurface_threshold=20, mesh_voxels=134217728, visibility_resolution=512, mesh_radius=1.0, mesh_max_radius=10.0, std_value=0.0, compute_visibility=False, extract_visibility=True, decimate_target=-1, vertex_color=True, vertex_projection=True, tsdf_radius=2.0, tsdf_resolution=512, truncation_margin=5.0, tsdf_max_radius=10.0)
2024-06-27 12:40:19: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: no
Warning: image_path not found for reconstruction
Warning: image_path not found for reconstruction
2024-06-27 12:40:24: Checkpoint does not exist. Starting a new training run.
2024-06-27 12:40:24: Number of parameters being optimized: 130306473
2024-06-27 12:40:24: Begin training...
Training: 0%| | 0/25000 [00:00<?, ?it/s]
2024-06-27 12:40:25: Error!
Traceback (most recent call last):
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/caio/.local/share/Trash/files/zipnerf-pytorch.2/train.py", line 390, in <module>
app.run(main)
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/caio/.local/share/Trash/files/zipnerf-pytorch.2/train.py", line 169, in main
renderings, ray_history = model(
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/caio/.local/share/Trash/files/zipnerf-pytorch.2/internal/models.py", line 229, in forward
ray_results = mlp(
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/caio/.local/share/Trash/files/zipnerf-pytorch.2/internal/models.py", line 523, in forward
raw_grad_density = torch.autograd.grad(
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/torch/autograd/__init__.py", line 412, in grad
result = _engine_run_backward(
File "/home/caio/anaconda3/envs/zipnerf/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
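About the INVALID_POINT3D change at the start of this comment: as far as I know, NumPy 2.x refuses to convert the negative Python integer -1 to an unsigned type (1.x releases wrapped it around to the maximum value, with a deprecation warning in recent ones), which would explain why the original line failed. A minimal sketch of a spelling that should work on both NumPy 1.x and 2.x, taking the largest uint64 value from np.iinfo instead of relying on wraparound:

```python
import numpy as np

# Largest uint64 value, used as an "invalid / no 3D point" sentinel.
# np.iinfo avoids converting the negative Python int -1 to an unsigned type,
# which newer NumPy versions reject with an OverflowError.
INVALID_POINT3D = np.uint64(np.iinfo(np.uint64).max)

# Same value as the np.uint64(2**64 - 1) workaround from this comment.
print(INVALID_POINT3D == np.uint64(2**64 - 1))
```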
Update: I also had to change this part of the code inside the forward method:

if self.disable_density_normals:
    raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
    raw_grad_density = None
    normals = None
else:
    with torch.enable_grad():
        means.requires_grad_(True)
        raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
        d_output = torch.ones_like(raw_density, requires_grad=False, device=raw_density.device)
        raw_grad_density = torch.autograd.grad(
            outputs=raw_density,
            inputs=means,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
            allow_unused=True
        )[0]
    if raw_grad_density is not None:
        raw_grad_density = raw_grad_density.mean(-2)
        # Compute normal vectors as negative normalized density gradient.
        # We normalize the gradient of raw (pre-activation) density because
        # it's the same as post-activation density, but is more numerically stable
        # when the activation function has a steep or flat gradient.
        normals = -ref_utils.l2_normalize(raw_grad_density)
    else:
        # Handle the case where raw_grad_density is None
        normals = None

Then I had to change the batch size in the file to:
batch_size: int = 2 ** 12  # The number of rays/pixels in each batch.

I also had to change the versions of matplotlib and numpy to make it work:
pip install matplotlib==3.7.3
pip install numpy==1.26.4

Now the training is running and will take almost 3 hours to train the bicycle scene :(
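A standalone sketch of what the modified block does: normals are the negative, normalized gradient of density with respect to the sample positions, and with allow_unused=True torch.autograd.grad returns None instead of raising the RuntimeError from the traceback when the density does not actually depend on those positions. This is a simplified illustration, not the repo's code: density_normals is a hypothetical helper, torch.nn.functional.normalize stands in for ref_utils.l2_normalize, and the .mean(-2) over samples is omitted.

```python
import torch
import torch.nn.functional as F


def density_normals(density_fn, means):
    """Negative, L2-normalized gradient of density w.r.t. positions.

    Returns None when the density does not depend on `means`, which is what
    torch.autograd.grad(..., allow_unused=True) reports in that case.
    """
    with torch.enable_grad():
        # Work on a detached copy so the caller's tensor is left unchanged.
        means = means.detach().requires_grad_(True)
        raw_density = density_fn(means)
        grad = torch.autograd.grad(
            outputs=raw_density,
            inputs=means,
            grad_outputs=torch.ones_like(raw_density),
            create_graph=True,  # keep the graph so a loss on the normals can backprop
            allow_unused=True,  # without this, an unused input raises RuntimeError
        )[0]
    if grad is None:
        return None
    return -F.normalize(grad, dim=-1)


pts = torch.tensor([[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]])

# Toy density peaked at the origin: its normals are outward unit vectors.
print(density_normals(lambda m: -(m ** 2).sum(-1), pts))

# Density that ignores the positions entirely: the gradient is unused.
w = torch.ones(1, requires_grad=True)
print(density_normals(lambda m: w.expand(m.shape[0]), pts))  # None
```

Note that allow_unused=True only hides the symptom: if the gradient comes back as None during real training, the density is not connected to `means` in the autograd graph, and the normals are simply skipped.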
So 3 hours is longer than you expected?
I have the same problem as you. Maybe it is because the download link has changed.
@caiobarrosv
Guys, I'm having a lot of problems trying to execute the train.py file.
OS: Ubuntu 22.04
Graphics card: RTX 3060
Driver version: 545.29.06
I installed CUDA 11.8 and configured .bashrc accordingly:
The output of the nvcc --version command is:

Then I followed these steps:
These are the packages installed after running pip install -r requirements.txt:

The first problem arises when I try to run the following command:
The output is:
It happens because the torch version installed from requirements.txt uses CUDA 12.1:
$ python -c "import torch; print(torch.version.cuda)"
12.1
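One way to spot this mismatch before compiling the extension is to compare the toolkit that nvcc reports against torch.version.cuda. A minimal sketch; it assumes nvcc's output contains a "release X.Y" string (for example "Cuda compilation tools, release 12.1, V12.1.105"), and the helper names are hypothetical:

```python
import re
import shutil
import subprocess


def nvcc_release(nvcc_output):
    """Return (major, minor) from `nvcc --version` text, or None if not found."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def torch_cuda_release(version):
    """Return (major, minor) from torch.version.cuda such as "12.1"; None for CPU builds."""
    if not version:
        return None
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


if __name__ == "__main__":
    try:
        import torch
        torch_cuda = torch_cuda_release(torch.version.cuda)
    except ImportError:
        torch_cuda = None

    nvcc_cuda = None
    nvcc = shutil.which("nvcc")  # the nvcc that PATH (e.g. from .bashrc) resolves to
    if nvcc is not None:
        out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
        nvcc_cuda = nvcc_release(out)

    print(f"nvcc on PATH: {nvcc_cuda}, torch built with: {torch_cuda}")
    if nvcc_cuda != torch_cuda:
        print("Versions differ; the CUDA extension may fail to build.")
```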
Therefore, I changed the CUDA version to 12.1 in .bashrc:
The g++ version is:
Even after changing the cuda version, I get a lot of errors:
extension_cuda.log
I then uninstalled ninja:
And changed the string -std=c++14 to -std=c++17 in extension/cuda/setup.py. After this change, everything compiles:
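For reference, the flag change corresponds roughly to the following in a PyTorch extension's setup.py. This is a sketch, not the repo's actual file: the extension name and source files are hypothetical placeholders. As far as I know, PyTorch 2.1 and later require C++17 to compile against their headers, which would explain why -std=c++14 fails with the torch 2.3.1 installed from requirements.txt.

```python
"""Sketch: C++17 compile flags for a PyTorch CUDA extension's setup.py.

Names here (hypothetical_cuda_ext, ext.cpp, ext_kernel.cu) are placeholders.
"""

# Recent PyTorch headers need C++17, so -std=c++14 breaks the build.
CXX_FLAGS = ["-O3", "-std=c++17"]
NVCC_FLAGS = ["-O3", "-std=c++17"]


def make_extension():
    # Imported lazily: CUDAExtension needs torch and a CUDA toolkit (CUDA_HOME).
    from torch.utils.cpp_extension import CUDAExtension

    return CUDAExtension(
        name="hypothetical_cuda_ext",
        sources=["ext.cpp", "ext_kernel.cu"],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
    )


# In setup.py this would be passed to setuptools, e.g.:
# setup(name="hypothetical_cuda_ext", ext_modules=[make_extension()],
#       cmdclass={"build_ext": BuildExtension})
```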
Finally, I tried to install torch-scatter for CUDA 12.1 and torch 2.3.1 (the version installed from requirements.txt).
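torch-scatter wheels have to match both the torch version and its CUDA build. A small sketch that derives the wheel index URL from those two values, assuming PyG's wheel index naming pattern (https://data.pyg.org/whl/torch-X.Y.0+cuNNN.html, one page per torch minor release); check the PyG installation docs before relying on it. The helper name is hypothetical.

```python
def pyg_wheel_index(torch_version, cuda_version):
    """Build a PyG wheel index URL from torch.__version__ and torch.version.cuda.

    Assumption: PyG publishes one index per torch minor release, named X.Y.0.
    """
    # torch.__version__ looks like "2.3.1+cu121"; drop the local "+..." suffix.
    major, minor = torch_version.split("+")[0].split(".")[:2]
    tag = "cpu" if not cuda_version else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{major}.{minor}.0+{tag}.html"


url = pyg_wheel_index("2.3.1+cu121", "12.1")
print(f"pip install torch-scatter -f {url}")
```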
When I try to run the train.py script with the bicycle dataset, I get:
Any thoughts on how to solve it? Thank you very much :)