Implemented right and full join #13
piccolbo committed Jul 14, 2015
1 parent a5bcdbd commit 6a198b2
Showing 3 changed files with 28 additions and 7 deletions.
pkg/NAMESPACE: 2 changes (2 additions, 0 deletions)
@@ -20,6 +20,8 @@ S3method(sql_join, SparkSQLConnection)
S3method(src_desc, src_SparkSQL)
S3method(src_translate_env, src_SparkSQL)
S3method(tbl, src_SparkSQL)
+S3method(right_join, tbl_SparkSQL)
+S3method(full_join, tbl_SparkSQL)
S3method(intersect, tbl_SparkSQL)
S3method(union, tbl_SparkSQL)
export(src_SparkSQL)
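
(Editorial note, not part of the commit.) The two new S3method() registrations above are what let a plain right_join()/full_join() call reach the Spark-specific methods added in the next file. A minimal standalone sketch of that S3 dispatch mechanism, using toy definitions rather than code from this package:

right_join <- function(x, y, ...) UseMethod("right_join")            # toy generic
right_join.default <- function(x, y, ...) "default method"
right_join.tbl_SparkSQL <- function(x, y, ...) "Spark SQL method"    # stand-in for the registered method
x <- structure(list(), class = c("tbl_SparkSQL", "tbl_sql", "tbl"))  # fake object carrying the class
right_join(x, NULL)                                                  # dispatches on class(x): "Spark SQL method"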
pkg/R/src-sparkSQL.R: 20 changes (19 additions, 1 deletion)
@@ -317,6 +317,24 @@ union.tbl_SparkSQL =
#under MIT license
intersect.tbl_SparkSQL =
function (x, y, copy = FALSE, ...){
-if(!all(colnames(x) == colnames(y)))
+if(suppressWarnings(!all(colnames(x) == colnames(y))))
stop("Tables not compatible")
inner_join(x, y, copy = copy)}
+
+#modeled after join methods in http://github.com/hadley/dplyr,
+#under MIT license
+some_join =
+  function (x, y, by = NULL, copy = FALSE, auto_index = FALSE, ..., type) {
+    by <- dplyr:::common_by(by, x, y)
+    y <- dplyr:::auto_copy(x, y, copy, indexes = if (auto_index)
+      list(by$y))
+    sql <- dplyr:::sql_join(x$src$con, x, y, type = type, by = by)
+    dplyr:::update.tbl_sql(tbl(x$src, sql), group_by = groups(x))}
+
+right_join.tbl_SparkSQL =
+  function (x, y, by = NULL, copy = FALSE, auto_index = FALSE, ...) {
+    some_join(x = x, y = y, by = by, copy = copy, auto_index = auto_index, ..., type = "right")}
+
+full_join.tbl_SparkSQL =
+  function (x, y, by = NULL, copy = FALSE, auto_index = FALSE, ...) {
+    some_join(x = x, y = y, by = by, copy = copy, auto_index = auto_index, ..., type = "full")}
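
(Editorial note, not part of the commit.) some_join() assembles the join through dplyr's internal helpers (common_by, auto_copy, sql_join) and returns a lazy tbl; the two thin wrappers just fix type = "right" or type = "full". A hedged usage sketch, assuming a reachable Spark SQL server, that src_SparkSQL() can be called with its connection arguments omitted here, and tables named df1/df2 as in the test file below:

library(dplyr)
library(dplyr.spark)

my_db <- src_SparkSQL()        # connection arguments omitted; assumed for this sketch
df1 <- tbl(my_db, "df1")       # table names as in the test file below
df2 <- tbl(my_db, "df2")

# joins are built lazily as Spark SQL; collect() executes them and fetches the result
df1 %>% right_join(df2) %>% collect()
df1 %>% full_join(df2) %>% collect()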
pkg/tests/two-table.R: 13 changes (7 additions, 6 deletions)
@@ -5,7 +5,7 @@ library(dplyr)
library(dplyr.spark)
Sys.setenv(
HADOOP_JAR =
"/Users/antonio/Projects/Revolution/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.6.0.jar")
"/Users/antonio/Projects/Revolution/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.4.0.jar")
assignInNamespace(
"unique_name",
function()
@@ -63,17 +63,18 @@ df2 = copy_to(my_db, df2, temporary = TRUE)
df1 = tbl(my_db, "df1")
df2 = tbl(my_db, "df2")

df1 %>% inner_join(df2)
+collect(df1 %>% inner_join(df2))

df1 %>% left_join(df2)
+collect(df1 %>% left_join(df2))

-# broken
-# df1 %>% right_join(df2)
-#
+collect(df1 %>% right_join(df2))
+
df2 %>% left_join(df1)
+collect(df2 %>% left_join(df1))

-#broken
-#df1 %>% full_join(df2)
+collect(df1 %>% full_join(df2))

planes = copy_to(my_db, planes)
planes = tbl(my_db, "planes")