From cf3c531e012c9c0ddc87485cab0ae3659ab976ae Mon Sep 17 00:00:00 2001
From: snordmann
Date: Mon, 20 Jan 2025 03:18:37 -0800
Subject: [PATCH 1/5] Host IR: add LinearOp with preallocated outputs

---
 csrc/host_ir/executor.cpp   | 55 +++++++++++++++++++++++++
 csrc/host_ir/executor.h     |  1 +
 csrc/ir/internal_nodes.h    |  1 -
 tests/cpp/test_host_irs.cpp | 82 +++++++++++++++++++++++++++++++++++++
 4 files changed, 138 insertions(+), 1 deletion(-)
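
Review note (ignored by git am): the handler added in executor.cpp takes the
at::linear_out path only when the output TensorView is already bound to a
buffer. Below is a minimal ATen-only sketch of that preallocated-output
contract; the shapes mirror the HostIrLinearOut test added in this patch, and
the helper name and options parameter are illustrative, not part of the patch.

  #include <ATen/ATen.h>

  void linearIntoPreallocatedBuffer(const at::TensorOptions& options) {
    at::Tensor in = at::randn({32, 64, 128}, options);  // [B, M, K]
    at::Tensor weight = at::randn({256, 128}, options); // [N, K]
    at::Tensor bias = at::randn({256}, options);        // [N]
    at::Tensor out = at::empty({32, 64, 256}, options); // [B, M, N], preallocated
    at::linear_out(out, in, weight, bias); // writes into out, no new allocation
    // Allocating reference for comparison:
    // at::Tensor ref = at::linear(in, weight, bias);
  }
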
diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 0f9f3da6921..102ae5feb3e 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -545,6 +545,61 @@ void HostIrEvaluator::handle(MatmulOp* matmul) {
   }
 }
 
+void HostIrEvaluator::handle(LinearOp* linear) {
+  TensorView* in = linear->inA()->as<TensorView>();
+  TensorView* weight = linear->inB()->as<TensorView>();
+  TensorView* bias = linear->bias()->as<TensorView>();
+  TensorView* out = linear->out()->as<TensorView>();
+  NVF_ERROR(
+      expr_evaluator_.isKnown(in)
+      && expr_evaluator_.isKnown(weight)
+      && (!linear->has_bias() || expr_evaluator_.isKnown(bias)),
+      "Inputs of the Linear Op ",
+      linear->toString(),
+      "must be precomputed before being retrieved");
+
+  if (!expr_evaluator_.isKnown(out)) {
+    unhandled(linear);
+    return;
+  }
+
+  auto squeeze_device_dims = [](at::Tensor& t,
+                                int64_t num_device_dims) -> void {
+    // Record the initial shape for the error message.
+    std::vector<int64_t> shape = t.sizes().vec();
+    for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
+      NVF_CHECK(
+          t.size(0) == 1,
+          "When the weight is >2D, expect its preceding dimensions and "
+          "the bias's preceding dimensions to "
+          "be DID-parallel and therefore size-1: ",
+          shape);
+      t = t.squeeze(0);
+    }
+  };
+
+  auto in_at = expr_evaluator_.evaluate(in).as<at::Tensor>();
+  auto weight_at = expr_evaluator_.evaluate(weight).as<at::Tensor>();
+  auto bias_at = expr_evaluator_.evaluate(bias).as<at::Tensor>();
+  auto out_at = expr_evaluator_.evaluate(out).as<at::Tensor>();
+
+  // The squeezes and unsqueezes are currently required to support a sharded
+  // linear layer. Remove them after #2563.
+  auto num_device_dims = weight_at.dim() - 2;
+  squeeze_device_dims(weight_at, num_device_dims);
+  if (linear->has_bias()) {
+    squeeze_device_dims(bias_at, num_device_dims);
+    at::linear_out(out_at, in_at, weight_at, bias_at);
+  } else {
+    at::linear_out(out_at, in_at, weight_at);
+  }
+
+  for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
+    out_at = out_at.unsqueeze(0);
+  }
+  expr_evaluator_.bind(out, out_at, /*evaluate_validate=*/false);
+}
+
 void HostIrEvaluator::handle(kir::Allocate* allocate) {
   NVF_ERROR(
       allocate->buffer()->isA<TensorView>(),
diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h
index 2797948975a..ad3e8422ca1 100644
--- a/csrc/host_ir/executor.h
+++ b/csrc/host_ir/executor.h
@@ -127,6 +127,7 @@ class HostIrEvaluator final : public OptOutDispatch {
   void handle(EndCoalescing* end_coalescing) override;
   void handle(kir::IfThenElse* if_then_else) override;
   void handle(MatmulOp* matmul) override;
+  void handle(LinearOp* linear) override;
   void handle(kir::Allocate* allocate) override;
   void unhandled(Statement* stmt) override;
 
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 6aebcb3c457..ff2e29b04af 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -2269,7 +2269,6 @@ class LinearOp : public Expr {
       const ExpressionEvaluator& ee,
       const std::vector<PolymorphicValue>& inputs) const override;
 
- private:
   bool has_bias() const {
     return inputs().size() == 3;
   }
diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp
index e97550309e1..54832acc8ac 100644
--- a/tests/cpp/test_host_irs.cpp
+++ b/tests/cpp/test_host_irs.cpp
@@ -849,6 +849,88 @@ TEST_F(MatmulHostIrTest, HostIrMatmulOut) {
   EXPECT_TRUE(ref_output.allclose(c_tensor));
 }
 
+using LinearHostIrTest = NVFuserTest;
+
+TEST_F(LinearHostIrTest, HostIr) {
+  constexpr int64_t B = 32;
+  constexpr int64_t M = 64;
+  constexpr int64_t K = 128;
+  constexpr int64_t N = 256;
+
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+
+  TensorView* in = makeContigTensor(3);
+  TensorView* weight = makeContigTensor(2);
+  TensorView* bias = makeContigTensor(1);
+  TensorView* out = linear(in, weight, bias);
+
+  hic->addInput(in);
+  hic->addInput(weight);
+  hic->addInput(bias);
+  hic->addOutput(out);
+
+  hic->pushBackTopLevelExprs(out->definition());
+
+  HostIrEvaluator hie(std::move(hic));
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat);
+  at::Tensor in_at = at::randn({B, M, K}, options);
+  at::Tensor weight_at = at::randn({N, K}, options);
+  at::Tensor bias_at = at::randn({N}, options);
+  std::unordered_map<Val*, c10::IValue> concrete_input_buffers = {
+      {hie.inputs().at(0), in_at}, {hie.inputs().at(1), weight_at}, {hie.inputs().at(2), bias_at}};
+
+  auto output = hie.runWithInput(concrete_input_buffers).at(0);
+
+  // validate
+  auto ref_output = at::linear(in_at, weight_at, bias_at);
+
+  EXPECT_TRUE(ref_output.allclose(output));
+}
+
+TEST_F(LinearHostIrTest, HostIrLinearOut) {
+  constexpr int64_t B = 32;
+  constexpr int64_t M = 64;
+  constexpr int64_t K = 128;
+  constexpr int64_t N = 256;
+
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+
+  TensorView* in = makeContigTensor(3);
+  TensorView* weight = makeContigTensor(2);
+  TensorView* bias = makeContigTensor(1);
+  TensorView* out = makeContigTensor(3);
+
+  auto linear_op = IrBuilder::create<LinearOp>(out, in, weight, bias);
+
+  hic->addInput(in);
+  hic->addInput(weight);
+  hic->addInput(bias);
+  hic->addInput(out);
+  hic->addOutput(out);
+
+  hic->pushBackTopLevelExprs(linear_op);
+
+  HostIrEvaluator hie(std::move(hic));
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat);
+  at::Tensor in_at = at::randn({B, M, K}, options);
+  at::Tensor weight_at = at::randn({N, K}, options);
+  at::Tensor bias_at = at::randn({N}, options);
+  at::Tensor out_at = at::empty({B, M, N}, options);
+  std::unordered_map<Val*, c10::IValue> concrete_input_buffers = {
+      {hie.inputs().at(0), in_at}, {hie.inputs().at(1), weight_at}, {hie.inputs().at(2), bias_at}, {hie.inputs().at(3), out_at}};
+
+  hie.runWithInput(concrete_input_buffers);
+
+  // validate
+  auto ref_output = at::linear(in_at, weight_at, bias_at);
+
+  EXPECT_TRUE(ref_output.allclose(out_at));
+}
+
 using SelectHostIrTestParams = bool;
 using SelectHostIrTest = NVFuserFixtureParamTest<SelectHostIrTestParams>;
 
From 9e179334edf213686718fe3db8040804ee642e46 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Mon, 20 Jan 2025 03:23:22 -0800
Subject: [PATCH 2/5] slightly simplify implementation and test

---
 csrc/host_ir/executor.cpp   | 29 ++---------------------------
 tests/cpp/test_host_irs.cpp |  1 -
 2 files changed, 2 insertions(+), 28 deletions(-)

diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 102ae5feb3e..f943ccc6927 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -563,41 +563,16 @@ void HostIrEvaluator::handle(LinearOp* linear) {
     return;
   }
 
-  auto squeeze_device_dims = [](at::Tensor& t,
-                                int64_t num_device_dims) -> void {
-    // Record the initial shape for the error message.
-    std::vector<int64_t> shape = t.sizes().vec();
-    for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
-      NVF_CHECK(
-          t.size(0) == 1,
-          "When the weight is >2D, expect its preceding dimensions and "
-          "the bias's preceding dimensions to "
-          "be DID-parallel and therefore size-1: ",
-          shape);
-      t = t.squeeze(0);
-    }
-  };
-
   auto in_at = expr_evaluator_.evaluate(in).as<at::Tensor>();
   auto weight_at = expr_evaluator_.evaluate(weight).as<at::Tensor>();
   auto bias_at = expr_evaluator_.evaluate(bias).as<at::Tensor>();
   auto out_at = expr_evaluator_.evaluate(out).as<at::Tensor>();
 
-  // The squeezes and unsqueezes are currently required to support a sharded
-  // linear layer. Remove them after #2563.
-  auto num_device_dims = weight_at.dim() - 2;
-  squeeze_device_dims(weight_at, num_device_dims);
   if (linear->has_bias()) {
-    squeeze_device_dims(bias_at, num_device_dims);
-    at::linear_out(out_at, in_at, weight_at, bias_at);
+    at::linear_out(out_at, in_at, weight_at.squeeze(), bias_at.squeeze());
   } else {
-    at::linear_out(out_at, in_at, weight_at);
-  }
-
-  for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
-    out_at = out_at.unsqueeze(0);
+    at::linear_out(out_at, in_at, weight_at.squeeze());
   }
-  expr_evaluator_.bind(out, out_at, /*evaluate_validate=*/false);
 }
 
 void HostIrEvaluator::handle(kir::Allocate* allocate) {
diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp
index 54832acc8ac..687072e172c 100644
--- a/tests/cpp/test_host_irs.cpp
+++ b/tests/cpp/test_host_irs.cpp
@@ -909,7 +909,6 @@ TEST_F(LinearHostIrTest, HostIrLinearOut) {
   hic->addInput(weight);
   hic->addInput(bias);
   hic->addInput(out);
-  hic->addOutput(out);
 
   hic->pushBackTopLevelExprs(linear_op);
 
From d4326626caf57ff9b0f65d1bb4d433fcc982381f Mon Sep 17 00:00:00 2001
From: snordmann
Date: Mon, 20 Jan 2025 03:24:06 -0800
Subject: [PATCH 3/5] lint

---
 csrc/host_ir/executor.cpp   | 5 ++---
 tests/cpp/test_host_irs.cpp | 9 +++++++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index f943ccc6927..170f4fcc6dc 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -551,9 +551,8 @@ void HostIrEvaluator::handle(LinearOp* linear) {
   TensorView* bias = linear->bias()->as<TensorView>();
   TensorView* out = linear->out()->as<TensorView>();
   NVF_ERROR(
-      expr_evaluator_.isKnown(in)
-      && expr_evaluator_.isKnown(weight)
-      && (!linear->has_bias() || expr_evaluator_.isKnown(bias)),
+      expr_evaluator_.isKnown(in) && expr_evaluator_.isKnown(weight) &&
+          (!linear->has_bias() || expr_evaluator_.isKnown(bias)),
       "Inputs of the Linear Op ",
       linear->toString(),
       "must be precomputed before being retrieved");
diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp
index 687072e172c..e0f41c70a91 100644
--- a/tests/cpp/test_host_irs.cpp
+++ b/tests/cpp/test_host_irs.cpp
@@ -879,7 +879,9 @@ TEST_F(LinearHostIrTest, HostIr) {
   at::Tensor weight_at = at::randn({N, K}, options);
   at::Tensor bias_at = at::randn({N}, options);
   std::unordered_map<Val*, c10::IValue> concrete_input_buffers = {
-      {hie.inputs().at(0), in_at}, {hie.inputs().at(1), weight_at}, {hie.inputs().at(2), bias_at}};
+      {hie.inputs().at(0), in_at},
+      {hie.inputs().at(1), weight_at},
+      {hie.inputs().at(2), bias_at}};
 
   auto output = hie.runWithInput(concrete_input_buffers).at(0);
 
@@ -920,7 +922,10 @@ TEST_F(LinearHostIrTest, HostIrLinearOut) {
   at::Tensor bias_at = at::randn({N}, options);
   at::Tensor out_at = at::empty({B, M, N}, options);
   std::unordered_map<Val*, c10::IValue> concrete_input_buffers = {
-      {hie.inputs().at(0), in_at}, {hie.inputs().at(1), weight_at}, {hie.inputs().at(2), bias_at}, {hie.inputs().at(3), out_at}};
+      {hie.inputs().at(0), in_at},
+      {hie.inputs().at(1), weight_at},
+      {hie.inputs().at(2), bias_at},
+      {hie.inputs().at(3), out_at}};
 
   hie.runWithInput(concrete_input_buffers);
 
From bff1abb377fca78203248f7b0d7f6c81f7941e0e Mon Sep 17 00:00:00 2001
From: snordmann
Date: Mon, 20 Jan 2025 04:37:01 -0800
Subject: [PATCH 4/5] Lower stream-parallelized LinearOp into Host IR AG+GEMM
 overlap algo

---
 csrc/host_ir/lower.cpp                 | 85 ++++++++++++++++++++------
 tests/cpp/test_multidevice_host_ir.cpp | 61 ++++++++++++++++++
 2 files changed, 127 insertions(+), 19 deletions(-)
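
Review note (ignored by git am): a rough C++/ATen sketch of the stream-pipelined
structure this lowering emits for out = linear(a, b, bias) with a sharded on
axis 1 and out stream-parallelized on axis 0. The per-iteration allgather that
the real host IR posts (and waits on) before each compute is elided here, and
pipelinedLinear is an illustrative name, not an API added by this patch.

  #include <ATen/ATen.h>
  #include <c10/cuda/CUDAGuard.h>
  #include <c10/cuda/CUDAStream.h>
  #include <vector>

  at::Tensor pipelinedLinear(
      const at::Tensor& a,        // [S, m, K]: one slice per stream iteration
      const at::Tensor& weight,   // [N, K]
      const at::Tensor& bias) {   // [N]
    const int64_t S = a.size(0);
    at::Tensor out = at::empty({S, a.size(1), weight.size(0)}, a.options());
    std::vector<c10::cuda::CUDAStream> streams;
    for (int64_t j = 0; j < S; ++j) {
      streams.push_back(c10::cuda::getStreamFromPool());
      c10::cuda::CUDAStreamGuard guard(streams.back());
      at::Tensor out_j = out.select(0, j);
      // The real lowering allgathers a[j] across devices here, then computes.
      at::linear_out(out_j, a.select(0, j), weight, bias);
    }
    for (const auto& stream : streams) {
      stream.synchronize(); // join the worker streams before handing out back
    }
    return out;
  }
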
diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp
index ea52ba5eeb6..1618a33e908 100644
--- a/csrc/host_ir/lower.cpp
+++ b/csrc/host_ir/lower.cpp
@@ -236,7 +236,7 @@ void lowerToReduceScatter(
 
 std::vector<Expr*> HostIrLower::lower(Expr* c) {
   FusionGuard fg(c->fusion());
-  if (c->isA<MatmulOp>()) {
+  if (c->isOneOf<MatmulOp, LinearOp>()) {
     return lowerToCollectiveBasedPipelinedGemmComm(c);
   }
 
@@ -342,30 +342,70 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
     }
     return ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
   } else if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
-    // For now we only support c = matmul(a,b) when b,c are fully replicated and
-    // a is sharded on axis 1
+    // For now we only support out = matmul(a,b) when b, out are fully
+    // replicated, a is sharded on axis 1, and out is stream-parallelized on
+    // axis 0.
     return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
         matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
         getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1 &&
         matmul->out()->axis(0)->getParallelType() == ParallelType::Stream;
+  } else if (auto* linear = dynamic_cast<LinearOp*>(expr)) {
+    // For now we only support out = linear(a, b, bias) when b, bias, and out
+    // are fully replicated, a is sharded on axis 1, and out is
+    // stream-parallelized on axis 0.
+    auto* a = linear->inA()->as<TensorView>();
+    auto* b = linear->inB()->as<TensorView>();
+    auto* bias = linear->bias()->as<TensorView>();
+    ;
+    auto* out = linear->out()->as<TensorView>();
+    ;
+    return !isSharded(b) && !(linear->has_bias() && isSharded(bias)) &&
+        !isSharded(out) &&
+        a->axis(0)->getParallelType() == ParallelType::Serial &&
+        getShardedLogicalAxis(a, ParallelType::DIDx) == 1 &&
+        out->axis(0)->getParallelType() == ParallelType::Stream;
   }
   return false;
 }
 
 std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
     Expr* expr) {
-  auto matmul = expr->as<MatmulOp>();
-  NVF_ERROR(matmul != nullptr, "Expect a MatmulOp, got", expr);
-  TensorView* tva = matmul->inA();
-  TensorView* tvb = matmul->inB();
-  TensorView* tvc = matmul->out();
+  NVF_ERROR(
+      (expr->isOneOf<MatmulOp, LinearOp>()),
+      "Expect a MatmulOp or a LinearOp, but got",
+      expr);
+
+  TensorView *tva, *tvb, *tv_bias, *tv_out;
+  if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
+    tva = matmul->inA();
+    tvb = matmul->inB();
+    tv_out = matmul->out();
+  } else if (auto* linear = dynamic_cast<LinearOp*>(expr)) {
+    tva = linear->inA()->as<TensorView>();
+    tvb = linear->inB()->as<TensorView>();
+    tv_bias = linear->bias()->as<TensorView>();
+    ;
+    tv_out = linear->out()->as<TensorView>();
+    ;
+    NVF_ERROR(
+        !(linear->has_bias() && isSharded(tv_bias)),
+        "The bias ",
+        tv_bias,
+        " is expected to not be sharded");
+  }
+
   NVF_ERROR(
       !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
   NVF_ERROR(
-      !isSharded(tvc),
+      !isSharded(tv_out),
       "The output ",
-      matmul->out(),
+      tv_out,
       " is expected to not be sharded");
+  NVF_ERROR(
+      tv_out->axis(0)->getParallelType() == ParallelType::Stream,
+      "The output ",
+      tv_out,
+      " is expected to be stream-parallelized on axis 0");
   const int64_t sharded_axis_index =
       getShardedLogicalAxis(tva, ParallelType::DIDx);
   IterDomain* stream_axis = tva->axis(0);
@@ -388,9 +428,9 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
   auto* allocate_tva_allgathered =
       IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);
 
-  tvc->setMemoryType(MemoryType::Global);
-  auto* allocate_tvc =
-      IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);
+  tv_out->setMemoryType(MemoryType::Global);
+  auto* allocate_tv_out =
+      IrBuilder::create<kir::Allocate>(tv_out, MemoryType::Global);
 
   auto* j =
       IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
@@ -417,14 +457,14 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
 
   TensorView* tva_j = select(tva, 0, j);
   TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
-  TensorView* tvc_j = select(tvc, 0, j);
+  TensorView* tv_out_j = select(tv_out, 0, j);
 
   NVF_ERROR(
       tva->hasDeviceMesh(),
       "The matmul's input ",
       tva,
       "is expected to have a DeviceMesh");
-  for (auto tv : {tva_j, tva_allgathered_j, tvc_j}) {
+  for (auto tv : {tva_j, tva_allgathered_j, tv_out_j}) {
     tv->setDeviceMesh(tva->getDeviceMesh());
   }
 
@@ -435,7 +475,13 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
       /*team=*/tva->getDeviceMesh().vector());
   auto* wait = IrBuilder::create<hir::Wait>(communication);
 
-  auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);
+  Expr* compute;
+  if (expr->isA<MatmulOp>()) {
+    compute = IrBuilder::create<MatmulOp>(tv_out_j, tva_allgathered_j, tvb);
+  } else if (expr->isA<LinearOp>()) {
+    compute =
+        IrBuilder::create<LinearOp>(tv_out_j, tva_allgathered_j, tvb, tv_bias);
+  }
 
   auto* set_back_original_stream =
       IrBuilder::create<hir::SetCurrentStream>(original_stream);
@@ -447,15 +493,16 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
       tva_allgathered_j->definition(),
       communication,
       wait,
-      tvc_j->definition(),
-      mm,
+      tv_out_j->definition(),
+      compute,
       set_back_original_stream,
       sync_stream};
   for (Expr* expr : loop_body) {
     for_loop->body().push_back(expr);
   }
 
-  return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
+  return {
+      get_current_stream, allocate_tva_allgathered, allocate_tv_out, for_loop};
 }
 
 std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp
index ef05c4a45ac..113c048b115 100644
--- a/tests/cpp/test_multidevice_host_ir.cpp
+++ b/tests/cpp/test_multidevice_host_ir.cpp
@@ -408,6 +408,67 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) {
   EXPECT_TRUE(torch::allclose(tc_ref, tc, 1e-2, 1e-2));
 }
 
+TEST_F(OverlapDistributedMatmulTest, AG_linear) {
+  constexpr int64_t M = 32768;
+  constexpr int64_t K = 32768;
+  constexpr int64_t N = 1024;
+  constexpr int64_t S = 8;
+  const int64_t D = communicator_->size();
+  if (M % (D * S) != 0) {
+    GTEST_SKIP() << "M must be a multiple of D * S, but got M = " << M
+                 << ", D = " << D << ", S = " << S;
+  }
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  TensorView* in = makeContigTensor(4); //[S, DIDx(D), M/(S*D), K]
+  TensorView* weight = makeContigTensor(2); //[N, K]
+  TensorView* bias = makeContigTensor(1); //[N]
+  TensorView* out = linear(in, weight, bias); //[S, D, M/(S*D), N]
+
+  fusion->addInput(in);
+  fusion->addInput(weight);
+  fusion->addInput(bias);
+  fusion->addOutput(out);
+
+  auto mesh = DeviceMesh::createForNumDevices(D);
+  in->setDeviceMesh(mesh);
+  weight->setDeviceMesh(mesh);
+  bias->setDeviceMesh(mesh);
+  out->setDeviceMesh(mesh);
+
+  in->axis(1)->parallelize(ParallelType::DIDx);
+  out->axis(0)->parallelize(ParallelType::Stream);
+
+  MultiDeviceExecutor executor(std::move(fusion), *communicator_);
+
+  auto tensor_options =
+      at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
+  at::Tensor in_at_unsharded =
+      at::randn({S, D, M / (S * D), K}, tensor_options);
+  at::Tensor in_at = in_at_unsharded.slice(
+      1, communicator_->deviceId(), communicator_->deviceId() + 1);
+  at::Tensor weight_at = at::randn({N, K}, tensor_options);
+  at::Tensor bias_at = at::randn({N}, tensor_options);
+  at::Tensor out_ref = at::linear(in_at_unsharded, weight_at, bias_at);
+
+  std::vector<c10::IValue> inputs = {in_at, weight_at, bias_at};
+  at::Tensor out_at;
+
+  constexpr int64_t kNumberOfIterations = 20;
+  constexpr int64_t kNumberOfWarmupIterations = 5;
+  for (auto i : c10::irange(kNumberOfIterations)) {
+    if (i == kNumberOfWarmupIterations) {
+      cudaProfilerStart();
+    }
+    out_at = executor.runWithInput(inputs).at(0);
+  }
+  cudaProfilerStop();
+
+  EXPECT_TRUE(torch::allclose(out_ref, out_at, 1e-2, 1e-2));
+}
+
 } // namespace hir
 
 } // namespace nvfuser
From 1fff21a44a548ddc85c753af81f3f2994173999b Mon Sep 17 00:00:00 2001
From: snordmann
Date: Mon, 20 Jan 2025 05:12:42 -0800
Subject: [PATCH 5/5] lint

---
 csrc/host_ir/lower.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp
index 1618a33e908..ca8a3bbe0db 100644
--- a/csrc/host_ir/lower.cpp
+++ b/csrc/host_ir/lower.cpp
@@ -374,19 +374,20 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
       (expr->isOneOf<MatmulOp, LinearOp>()),
       "Expect a MatmulOp or a LinearOp, but got",
       expr);
-
-  TensorView *tva, *tvb, *tv_bias, *tv_out;
+  TensorView* tva = nullptr;
+  TensorView* tvb = nullptr;
+  TensorView* tv_bias = nullptr;
+  TensorView* tv_out = nullptr;
   if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
     tva = matmul->inA();
     tvb = matmul->inB();
     tv_out = matmul->out();
-  } else if (auto* linear = dynamic_cast<LinearOp*>(expr)) {
+  } else {
+    auto* linear = dynamic_cast<LinearOp*>(expr);
     tva = linear->inA()->as<TensorView>();
     tvb = linear->inB()->as<TensorView>();
     tv_bias = linear->bias()->as<TensorView>();
-    ;
     tv_out = linear->out()->as<TensorView>();
-    ;
     NVF_ERROR(
         !(linear->has_bias() && isSharded(tv_bias)),
         "The bias ",
@@ -475,10 +476,10 @@ std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
       /*team=*/tva->getDeviceMesh().vector());
   auto* wait = IrBuilder::create<hir::Wait>(communication);
 
-  Expr* compute;
+  Expr* compute = nullptr;
   if (expr->isA<MatmulOp>()) {
     compute = IrBuilder::create<MatmulOp>(tv_out_j, tva_allgathered_j, tvb);
-  } else if (expr->isA<LinearOp>()) {
+  } else {
     compute =
         IrBuilder::create<LinearOp>(tv_out_j, tva_allgathered_j, tvb, tv_bias);
   }