@@ -2,48 +2,207 @@ package net.sansa_stack.inference.rules.plan
22
33import scala .collection .mutable
44
5- import org .apache .jena .graph .Node
5+ import org .apache .jena .graph .{ Node , Triple }
66import org .apache .jena .reasoner .rulesys .Rule
77
8+ import net .sansa_stack .inference .data .{SQLSchema , SQLSchemaDefault }
89import net .sansa_stack .inference .utils .RuleUtils .RuleExtension
910import net .sansa_stack .inference .utils .TripleUtils ._
10- import net .sansa_stack .inference .utils .{Logging , RuleUtils }
11+ import net .sansa_stack .inference .utils .{Logging , RuleUtils , TripleUtils }
1112
1213/**
14+ * A simple implementation of a SQL generator:
15+ * Joins are generated for common triple patterns in the body.
16+ * Projection variables are based on the head.
1317 * @author Lorenz Buehmann
1418 */
15- class SimpleSQLGenerator extends SQLGenerator with Logging {
19+ class SimpleSQLGenerator (val sqlSchema : SQLSchema = SQLSchemaDefault ) extends SQLGenerator with Logging {
20+
/** Table aliases handed out so far, keyed by the triple pattern they belong to. */
val aliases = mutable.HashMap.empty[Triple, String]

/** Numeric suffix used for the next alias that gets created. */
var idx = 0
1623
/**
 * Builds a SQL query string for the given rule.
 *
 * The body triple patterns provide the FROM and WHERE parts, the first head
 * triple pattern provides the projection. Alias state is cleared before each run.
 *
 * @param rule the rule to translate
 * @return the generated SQL query
 */
def generateSQLQuery(rule: Rule): String = {
  info(s" Rule: \n $rule")

  reset()

  val body = rule.bodyTriplePatterns().map(tp => tp.toTriple).toSet
  // only the first head triple pattern is considered
  val head = rule.headTriplePatterns().head.asTriple()

  // evaluated in this order because aliases are assigned on first use
  val selectClause = " SELECT " + projectionPart(body, head)
  val fromClause = fromPart(body)
  val whereClause = wherePart(body)

  val sql = Seq(selectClause, fromClause, whereClause).map(_ + " \n ").mkString

  info(s" SQL Query: \n $sql")

  sql
}
2144
22- val visited = mutable.Set [org.apache.jena.graph.Triple ]()
/** Restores the alias bookkeeping to its initial state: counter at 0, no aliases. */
private def reset(): Unit = {
  idx = 0
  aliases.clear()
}
2349
24- // process(body.head, body, visited)
/**
 * Computes the joins between triple patterns that share a variable.
 *
 * For every variable, each 2-combination of the patterns mentioning it
 * (patterns ordered by their string form) yields one [[Join]] on that variable.
 *
 * @param triples the triple patterns to inspect
 * @return the set of joins
 */
private def determineJoins(triples: Set[Triple]): Set[Join] = {
  // index: variable -> all triple patterns that mention it
  val patternsByVar =
    new mutable.HashMap[Node, mutable.Set[Triple]]() with mutable.MultiMap[Node, Triple]

  for (tp <- triples) {
    RuleUtils.varsOf(tp).foreach(v => patternsByVar.addBinding(v, tp))
  }

  patternsByVar.flatMap { case (joinVar, patterns) =>
    val ordered = patterns.toList.sortBy(_.toString)
    ordered.combinations(2).map(pair => new Join(pair(0), pair(1), joinVar))
  }.toSet
}
69+
/**
 * Builds the projection list of the query from the nodes of the head triple.
 *
 * A variable is projected as the column of the first body pattern (in iteration
 * order of `body`) for which `expressionFor` yields something other than the NULL
 * marker; a variable found in no body pattern contributes nothing. Any other node
 * is emitted as a quoted constant.
 *
 * @param body the triple patterns of the rule body
 * @param head the head triple whose nodes are projected
 * @return the projection expressions joined by commas
 */
private def projectionPart(body: Set[Triple], head: Triple): String = {
  val expressions = mutable.ArrayBuffer[String]()

  TripleUtils.nodes(head).foreach { node =>
    if (!node.isVariable) {
      expressions += " '" + node.toString + " '"
    } else {
      // lazy: stops asking body patterns as soon as one provides a column
      body.iterator
        .map(tp => expressionFor(node, tp))
        .find(_ != " NULL")
        .foreach(expr => expressions += expr)
    }
  }

  " " + expressions.mkString(" , ")
}
102+
/**
 * Builds the FROM clause: the first body pattern's table followed by one
 * INNER JOIN per consecutive pair of body patterns (1,2), (2,3), (3,4), ...
 *
 * A body with a single triple pattern produces a plain FROM clause without
 * joins; previously this case failed because the list of pairs was empty.
 *
 * Each consecutive pair must share a variable, since the join condition is
 * looked up via `joinsFor`, which fails if no join exists for the pair.
 *
 * @param body the triple patterns of the rule body
 * @return the FROM clause
 * @throws NoSuchElementException if `body` is empty
 */
private def fromPart(body: Set[Triple]): String = {
  val joins = determineJoins(body)

  val patterns = body.toList

  // one clause per consecutive pair; a single-element window is skipped by collect
  val joinClauses = patterns.sliding(2).collect { case List(previous, current) =>
    " INNER JOIN " + fromPart(current) + " ON " +
      joinExpressionFor(joinsFor(previous, current, joins)) + " "
  }

  // the first table is emitted before the (lazy) join clauses are materialised
  " FROM " + fromPart(patterns.head) + joinClauses.mkString
}
128+
/**
 * Picks a join that connects the two given triple patterns.
 *
 * @param tp1 the first triple pattern
 * @param tp2 the second triple pattern
 * @param joins the candidate joins
 * @return the first join (in iteration order of the filtered set) involving both patterns
 * @throws NoSuchElementException if no join involves both patterns
 */
private def joinsFor(tp1: Triple, tp2: Triple, joins: Set[Join]): Join = {
  def involves(join: Join, tp: Triple): Boolean = join.tp1 == tp || join.tp2 == tp

  val candidates = joins.filter(join => involves(join, tp1) && involves(join, tp2))
  candidates.head
}
132+
/**
 * Builds the WHERE clause from the constant (non-variable) positions of the
 * body triple patterns, combined with AND.
 *
 * If no pattern contains a constant there is nothing to filter on, so an empty
 * string is returned instead of a dangling WHERE keyword.
 *
 * @param body the triple patterns of the rule body
 * @return the WHERE clause, or an empty string if there are no conditions
 */
private def wherePart(body: Set[Triple]): String = {
  val conditions = body.flatMap(tp => whereParts(tp))

  if (conditions.isEmpty) {
    ""
  } else {
    " WHERE " + conditions.mkString(" AND ")
  }
}
144+
/**
 * Returns the alias of the given triple pattern, creating and remembering a
 * new one (prefix plus running counter) on first request.
 *
 * @param tp the triple pattern
 * @return the alias assigned to `tp`
 */
private def uniqueAliasFor(tp: Triple): String =
  aliases.getOrElseUpdate(tp, {
    val fresh = " rel" + idx
    idx += 1
    fresh
  })
155+
/**
 * Renders the equality condition of a join: the column holding the join
 * variable in the first pattern equals the one in the second pattern.
 *
 * @param join the join to render
 * @return the join condition
 */
private def joinExpressionFor(join: Join): String = {
  val left = expressionFor(join.joinVar, join.tp1)
  val right = expressionFor(join.joinVar, join.tp2)
  s"$left =$right"
}
159+
/**
 * FROM-clause entry for a single triple pattern, i.e. its aliased table name.
 *
 * @param tp the triple pattern
 * @return the table reference for `tp`
 */
private def fromPart(tp: Triple): String = tableName(tp)
163+
/**
 * Returns the column of the triple pattern that matches the given node,
 * checking subject, predicate and object in that order.
 *
 * @param variable the node to look for
 * @param tp the triple pattern to look in
 * @return the qualified column name, or the NULL marker if no position matches
 */
private def expressionFor(variable: Node, tp: Triple): String = tp match {
  case t if t.subjectMatches(variable) => subjectColumnName(t)
  case t if t.predicateMatches(variable) => predicateColumnName(t)
  case t if t.objectMatches(variable) => objectColumnName(t)
  case _ => " NULL"
}
175+
/**
 * Collects one equality condition per non-variable position (subject,
 * predicate, object) of the given triple pattern.
 *
 * @param tp the triple pattern
 * @return the conditions; empty if every position is a variable
 */
private def whereParts(tp: Triple): mutable.Set[String] = {
  // column names are wrapped in thunks so they are only resolved for constant positions
  val positions = Seq(
    (tp.getSubject, () => subjectColumnName(tp)),
    (tp.getPredicate, () => predicateColumnName(tp)),
    (tp.getObject, () => objectColumnName(tp))
  )

  val conditions = mutable.Set[String]()
  for ((node, column) <- positions if !node.isVariable) {
    conditions += column() + " ='" + node + " '"
  }
  conditions
}
192+
/**
 * Qualified name of the subject column for the given triple pattern.
 *
 * @param tp the triple pattern
 * @return the pattern's alias followed by the schema's subject column
 */
private def subjectColumnName(tp: Triple): String =
  s"${uniqueAliasFor(tp)} .${sqlSchema.subjectCol}"
196+
/**
 * Qualified name of the predicate column for the given triple pattern.
 *
 * @param tp the triple pattern
 * @return the pattern's alias followed by the schema's predicate column
 */
private def predicateColumnName(tp: Triple): String =
  s"${uniqueAliasFor(tp)} .${sqlSchema.predicateCol}"
200+
/**
 * Qualified name of the object column for the given triple pattern.
 *
 * @param tp the triple pattern
 * @return the pattern's alias followed by the schema's object column
 */
private def objectColumnName(tp: Triple): String =
  s"${uniqueAliasFor(tp)} .${sqlSchema.objectCol}"
46204
47- sqlQuery
/**
 * Table reference for the given triple pattern: the schema's triples table
 * followed by the pattern's alias.
 *
 * @param tp the triple pattern
 * @return the aliased table name
 */
private def tableName(tp: Triple): String =
  s"${sqlSchema.triplesTable} ${uniqueAliasFor(tp)}"
49208}
0 commit comments