Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit f482bb8

Browse files
Aligned rule to SQL converter.
1 parent 78085b3 commit f482bb8

3 files changed

Lines changed: 180 additions & 21 deletions

File tree

sansa-inference-common/src/main/scala/net/sansa_stack/inference/rules/plan/SimplePlanGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class SimplePlanGenerator(schema: SchemaPlus) {
8282
*/
8383
def generateLogicalPlan(rules: Seq[Rule]): RelNode = {
8484
// generate SQL query
85-
val sqlQuery = rules.map(sqlGenerator.generateSQLQuery _).mkString(" UNION ")
85+
val sqlQuery = rules.map(sqlGenerator.generateSQLQuery _).mkString("\tUNION \n")
8686

8787
// parse to SQL node
8888
val sqlNode = Try(planner.parse(sqlQuery))

sansa-inference-common/src/main/scala/net/sansa_stack/inference/rules/plan/SimpleSQLGenerator.scala

Lines changed: 176 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,207 @@ package net.sansa_stack.inference.rules.plan
22

33
import scala.collection.mutable
44

5-
import org.apache.jena.graph.Node
5+
import org.apache.jena.graph.{Node, Triple}
66
import org.apache.jena.reasoner.rulesys.Rule
77

8+
import net.sansa_stack.inference.data.{SQLSchema, SQLSchemaDefault}
89
import net.sansa_stack.inference.utils.RuleUtils.RuleExtension
910
import net.sansa_stack.inference.utils.TripleUtils._
10-
import net.sansa_stack.inference.utils.{Logging, RuleUtils}
11+
import net.sansa_stack.inference.utils.{Logging, RuleUtils, TripleUtils}
1112

1213
/**
14+
* A simple implementation of a SQL generator:
15+
* Joins are generated for common triple patterns in the body.
16+
* Projection variables are based on the head.
1317
* @author Lorenz Buehmann
1418
*/
15-
class SimpleSQLGenerator extends SQLGenerator with Logging {
19+
class SimpleSQLGenerator(val sqlSchema: SQLSchema = SQLSchemaDefault) extends SQLGenerator with Logging {
20+
21+
val aliases = new mutable.HashMap[Triple, String]()
22+
var idx = 0
1623

1724
/**
 * Generates a single SQL SELECT statement for the given rule.
 *
 * The statement is assembled as `SELECT <projection>\n<from>\n<where>\n`, where the
 * three parts are produced by `projectionPart`, `fromPart` and `wherePart` from the
 * rule's body triple patterns and its first head triple pattern.
 *
 * NOTE(review): alias bookkeeping (`aliases`, `idx`) lives on the instance and is
 * reset on every call, so concurrent calls on one instance would presumably
 * interfere with each other - confirm before sharing an instance across threads.
 *
 * @param rule the rule to translate
 * @return the generated SQL query string
 */
def generateSQLQuery(rule: Rule): String = {
1825
info(s"Rule:\n$rule")
1926

27+
// start from a clean alias table so aliases are numbered from rel0 for every rule
reset()
28+
2029
// body patterns are collected into a Set, so duplicates collapse
val body = rule.bodyTriplePatterns().map(tp => tp.toTriple).toSet
30+
// only the first head triple pattern is considered; any further head patterns are ignored
val head = rule.headTriplePatterns().head.asTriple()
31+
32+
var sql = "SELECT "
33+
34+
sql += projectionPart(body, head) + "\n"
35+
36+
sql += fromPart(body) + "\n"
37+
38+
sql += wherePart(body) + "\n"
39+
40+
info(s"SQL Query:\n$sql")
41+
42+
sql
43+
}
2144

22-
val visited = mutable.Set[org.apache.jena.graph.Triple]()
45+
/**
 * Clears the triple-pattern-to-alias map and resets the alias counter to 0.
 */
private def reset(): Unit = {
46+
aliases.clear()
47+
idx = 0
48+
}
2349

24-
// process(body.head, body, visited)
50+
/**
 * Computes the joins between the given triple patterns.
 *
 * Triple patterns are grouped by the variables returned by `RuleUtils.varsOf`; for
 * each variable, every 2-combination of the patterns containing it (sorted by their
 * string form) becomes one `Join` on that variable.
 *
 * @param triples the triple patterns to analyse
 * @return the set of joins, one per (pattern pair, shared variable)
 */
private def determineJoins(triples: Set[Triple]): Set[Join] = {
2551

2652
// group triple patterns by var
27-
val map = new mutable.HashMap[Node, collection.mutable.Set[org.apache.jena.graph.Triple]]() with mutable.MultiMap[Node, org.apache.jena.graph.Triple]
28-
body.foreach { tp =>
53+
val var2TPs = new mutable.HashMap[Node, collection.mutable.Set[org.apache.jena.graph.Triple]]() with mutable.MultiMap[Node, org.apache.jena.graph.Triple]
54+
triples.foreach { tp =>
2955
val vars = RuleUtils.varsOf(tp)
3056
vars.foreach { v =>
31-
map.addBinding(v, tp)
57+
var2TPs.addBinding(v, tp)
3258
}
3359
}
3460

35-
val joins = new mutable.HashSet[Join]
36-
37-
map.foreach{e =>
61+
val joins = var2TPs.flatMap{e =>
3862
val v = e._1
39-
val tps = e._2.toList.sortBy(_.toString).combinations(2).foreach(c =>
40-
joins.add(new Join(c(0), c(1), v))
41-
)
63+
64+
e._2.toList.sortBy(_.toString).combinations(2).map(c => new Join(c(0), c(1), v))
65+
}.toSet
66+
67+
joins
68+
}
69+
70+
/**
 * Builds the comma-separated projection list of the SELECT clause.
 *
 * For each node returned by `TripleUtils.nodes(head)`: a variable is projected as
 * the column expression of the first body pattern for which `expressionFor` does
 * not return "NULL"; any other node is emitted as a single-quoted literal.
 *
 * NOTE(review): a head variable that no body pattern binds contributes no
 * expression at all, so the projection silently ends up with fewer columns.
 * NOTE(review): constant nodes are quoted without escaping embedded quotes.
 *
 * @param body the body triple patterns of the rule
 * @param head the head triple pattern whose nodes are projected
 * @return the projection expressions joined by ", "
 */
private def projectionPart(body: Set[Triple], head: Triple): String = {
71+
var sql = ""
72+
73+
val requiredVars = TripleUtils.nodes(head)
74+
75+
val expressions = mutable.ArrayBuffer[String]()
76+
77+
// expressions += (if(target.getSubject.isVariable) expressionFor(target.getSubject, target) else target.getSubject.toString)
78+
// expressions += (if(target.getPredicate.isVariable) expressionFor(target.getPredicate, target) else target.getPredicate.toString)
79+
// expressions += (if(target.getObject.isVariable) expressionFor(target.getObject, target) else target.getObject.toString)
80+
81+
requiredVars.foreach{ v =>
82+
if (v.isVariable) {
83+
// flag used to stop after the first body pattern that binds the variable
var done = false
84+
85+
for(tp <- body; if !done) {
86+
// expressionFor may register an alias for tp as a side effect (via uniqueAliasFor)
val expr = expressionFor(v, tp)
87+
88+
// "NULL" is the sentinel expressionFor returns when tp does not contain the variable
if(expr != "NULL") {
89+
expressions += expr
90+
done = true
91+
}
92+
}
93+
} else {
94+
expressions += "'" + v.toString + "'"
95+
}
96+
}
97+
98+
sql += expressions.mkString(", ")
99+
100+
sql
101+
}
102+
103+
/**
 * Builds the FROM clause as a chain of INNER JOINs over the body triple patterns.
 *
 * The patterns are turned into a list and paired up as consecutive neighbours
 * (1,2), (2,3), ...; the first pair yields `t1 INNER JOIN t2 ON ...`, every further
 * pair appends `INNER JOIN t(n+1) ON ...` using the join found by `joinsFor`.
 *
 * NOTE(review): for a body with fewer than two triple patterns `list` is empty and
 * `list.head` throws NoSuchElementException.
 * NOTE(review): the pairing follows the iteration order of the Set; two neighbours
 * that share no variable make `joinsFor` fail.
 *
 * @param body the body triple patterns of the rule
 * @return the FROM clause, starting with " FROM "
 */
private def fromPart(body: Set[Triple]): String = {
104+
val joins = determineJoins(body)
105+
106+
var sql = " FROM "
107+
108+
// convert to list of pairs (1,2), (2,3), (3,4)
109+
// windows shorter than two elements are dropped by the partial function
val list = body.toList.sliding(2).collect { case List(a, b) => (a, b) }.toList
110+
111+
// the first pair contributes both of its tables
val pair = list.head
112+
val tp1 = pair._1
113+
val tp2 = pair._2
114+
sql += fromPart(tp1) + " INNER JOIN " + fromPart(tp2) + " ON " + joinExpressionFor(joinsFor(tp1, tp2, joins)) + " "
115+
116+
// every further pair only adds its second table; its first one is already part of the chain
for (i <- 1 until list.length) {
117+
val pair = list(i)
118+
val tp1 = pair._1
119+
val tp2 = pair._2
120+
sql += " INNER JOIN " + fromPart(tp2) + " ON " + joinExpressionFor(joinsFor(tp1, tp2, joins)) + " "
121+
}
122+
123+
124+
// sql += triplePatterns.map(tp => fromPart(tp)).mkString(" INNER JOIN ")
125+
// sql += " ON " + joins.map(join => joinExpressionFor(join)).mkString(" AND ")
126+
sql
127+
}
128+
129+
/**
 * Returns a join from `joins` that involves both `tp1` and `tp2`, in either order.
 *
 * NOTE(review): `.head` throws NoSuchElementException when the two patterns have no
 * join; when they share several variables only one of the matching joins is used.
 *
 * @param tp1 the first triple pattern
 * @param tp2 the second triple pattern
 * @param joins the candidate joins
 * @return the first matching join in the iteration order of the filtered set
 */
private def joinsFor(tp1: Triple, tp2: Triple, joins: Set[Join]): Join = {
130+
joins.filter(join => (join.tp1 == tp1 || join.tp2 == tp1) && (join.tp1 == tp2 || join.tp2 == tp2)).head
131+
}
132+
133+
/**
 * Builds the WHERE clause from the constant terms of all body triple patterns.
 *
 * The conditions produced by `whereParts` for each pattern are combined with " AND ".
 *
 * NOTE(review): when no pattern contains a constant term the result is just
 * " WHERE " with no condition, which is presumably not valid SQL - confirm.
 *
 * @param body the body triple patterns of the rule
 * @return the WHERE clause, starting with " WHERE "
 */
private def wherePart(body: Set[Triple]): String = {
134+
var sql = " WHERE "
135+
val expressions = mutable.ArrayBuffer[String]()
136+
137+
expressions ++= body.flatMap(tp => whereParts(tp))
138+
// expressions ++= joins.map(join => joinExpressionFor(join))
139+
140+
sql += expressions.mkString(" AND ")
141+
142+
sql
143+
}
144+
145+
/**
 * Returns the table alias for a triple pattern, creating it on first use.
 *
 * New aliases are named "rel" followed by the current counter value; the alias is
 * stored in `aliases` and the counter is incremented.
 *
 * @param tp the triple pattern
 * @return the alias registered for `tp`
 */
private def uniqueAliasFor(tp: Triple): String = {
146+
aliases.get(tp) match {
147+
case Some(alias) => alias
148+
case _ =>
149+
val alias = "rel" + idx
150+
aliases += tp -> alias
151+
idx += 1
152+
alias
153+
}
154+
}
155+
156+
/**
 * Renders a join as an equality between the join variable's column expression in
 * the first pattern and its column expression in the second pattern.
 *
 * @param join the join to render
 * @return a condition of the form `<expr1>=<expr2>`
 */
private def joinExpressionFor(join: Join): String = {
157+
expressionFor(join.joinVar, join.tp1) + "=" + expressionFor(join.joinVar, join.tp2)
158+
}
159+
160+
/**
 * Returns the FROM-clause entry for a single triple pattern; delegates to `tableName`.
 *
 * @param tp the triple pattern
 * @return the table reference for `tp`
 */
private def fromPart(tp: Triple): String = {
161+
tableName(tp)
162+
}
163+
164+
/**
 * Returns the qualified column of `tp` that corresponds to `variable`.
 *
 * Subject, predicate and object are checked in that order via `subjectMatches`,
 * `predicateMatches` and `objectMatches`; the first position that matches wins.
 *
 * @param variable the node to look up
 * @param tp the triple pattern to search
 * @return the qualified column name, or the string "NULL" if no position matches
 */
private def expressionFor(variable: Node, tp: Triple): String = {
165+
if (tp.subjectMatches(variable)) {
166+
subjectColumnName(tp)
167+
} else if (tp.predicateMatches(variable)) {
168+
predicateColumnName(tp)
169+
} else if (tp.objectMatches(variable)) {
170+
objectColumnName(tp)
171+
} else {
172+
"NULL"
173+
}
174+
}
175+
176+
/**
 * Collects the equality conditions for the non-variable terms of a triple pattern.
 *
 * For each of subject, predicate and object that is not a variable, a condition
 * `<column>='<term>'` is added.
 *
 * NOTE(review): the term is embedded between single quotes without escaping.
 *
 * @param tp the triple pattern
 * @return the set of conditions; empty when all three terms are variables
 */
private def whereParts(tp: Triple): mutable.Set[String] = {
177+
val res = mutable.Set[String]()
178+
179+
if (!tp.getSubject.isVariable) {
180+
res += subjectColumnName(tp) + "='" + tp.getSubject + "'"
181+
}
182+
183+
if (!tp.getPredicate.isVariable) {
184+
res += predicateColumnName(tp) + "='" + tp.getPredicate + "'"
42185
}
43186

44-
val sqlQuery = new Plan(body, rule.headTriplePatterns().toList.head.asTriple(), joins).toSQL
45-
info(s"SQL Query:\n$sqlQuery")
187+
if (!tp.getObject.isVariable) {
188+
res += objectColumnName(tp) + "='" + tp.getObject + "'"
189+
}
190+
res
191+
}
192+
193+
/**
 * Returns the subject column of `tp`, qualified as `<alias>.<sqlSchema.subjectCol>`.
 */
private def subjectColumnName(tp: Triple): String = {
194+
uniqueAliasFor(tp) + "." + sqlSchema.subjectCol
195+
}
196+
197+
/**
 * Returns the predicate column of `tp`, qualified as `<alias>.<sqlSchema.predicateCol>`.
 */
private def predicateColumnName(tp: Triple): String = {
198+
uniqueAliasFor(tp) + "." + sqlSchema.predicateCol
199+
}
200+
201+
/**
 * Returns the object column of `tp`, qualified as `<alias>.<sqlSchema.objectCol>`.
 */
private def objectColumnName(tp: Triple): String = {
202+
uniqueAliasFor(tp) + "." + sqlSchema.objectCol
203+
}
46204

47-
sqlQuery
205+
/**
 * Returns the table reference for `tp`: the triples table name from the schema
 * followed by a space and the pattern's alias.
 */
private def tableName(tp: Triple): String = {
206+
sqlSchema.triplesTable + " " + uniqueAliasFor(tp)
48207
}
49208
}

sansa-inference-common/src/main/scala/net/sansa_stack/inference/rules/plan/TriplesTableFactory.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class TriplesTableFactory extends TableFactory[Table] {
3333
val protoRowType = new RelProtoDataType() {
3434
override def apply(a0: RelDataTypeFactory): RelDataType = {
3535
a0.builder()
36-
.add("subject", SqlTypeName.VARCHAR, 200)
37-
.add("predicate", SqlTypeName.VARCHAR, 200)
38-
.add("object", SqlTypeName.VARCHAR, 200)
36+
.add("s", SqlTypeName.VARCHAR, 200)
37+
.add("p", SqlTypeName.VARCHAR, 200)
38+
.add("o", SqlTypeName.VARCHAR, 200)
3939
.build()
4040
}
4141
}

0 commit comments

Comments
 (0)